Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement AutoWrap and AutoWrapFull guides #2928

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions docs/source/infer.autoguide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,22 @@ AutoStructured
:special-members: __call__
:show-inheritance:

AutoWrap
--------
.. autoclass:: pyro.infer.autoguide.AutoWrap
:members:
:undoc-members:
:private-members: _sample, _setup_prototype
:show-inheritance:

AutoWrapFull
------------
.. autoclass:: pyro.infer.autoguide.AutoWrapFull
:members:
:undoc-members:
:private-members: _sample, _setup_prototype
:show-inheritance:

.. _autoguide-initialization:

Initialization
Expand Down
3 changes: 3 additions & 0 deletions pyro/infer/autoguide/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
init_to_value,
)
from pyro.infer.autoguide.utils import mean_field_entropy
from pyro.infer.autoguide.wrap import AutoWrap, AutoWrapFull

__all__ = [
"AutoCallable",
Expand All @@ -43,6 +44,8 @@
"AutoNormal",
"AutoNormalizingFlow",
"AutoStructured",
"AutoWrap",
"AutoWrapFull",
"init_to_feasible",
"init_to_generated",
"init_to_mean",
Expand Down
74 changes: 22 additions & 52 deletions pyro/infer/autoguide/guides.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,45 +34,15 @@ def model():
from pyro.distributions import constraints
from pyro.distributions.transforms import affine_autoregressive, iterated
from pyro.distributions.util import eye_like, is_identically_zero, sum_rightmost
from pyro.infer.autoguide.initialization import (
InitMessenger,
init_to_feasible,
init_to_median,
)
from pyro.infer.enum import config_enumerate
from pyro.infer.inspect import get_dependencies
from pyro.nn import PyroModule, PyroParam
from pyro.ops.hessian import hessian
from pyro.ops.tensor_utils import periodic_repeat
from pyro.poutine.util import site_is_subsample

from .utils import _product, helpful_support_errors


def _deep_setattr(obj, key, val):
"""
Set an attribute `key` on the object. If any of the prefix attributes do
not exist, they are set to :class:`~pyro.nn.PyroModule`.
"""

def _getattr(obj, attr):
obj_next = getattr(obj, attr, None)
if obj_next is not None:
return obj_next
setattr(obj, attr, PyroModule())
return getattr(obj, attr)

lpart, _, rpart = key.rpartition(".")
# Recursive getattr while setting any prefix attributes to PyroModule
if lpart:
obj = functools.reduce(_getattr, [obj] + lpart.split("."))
setattr(obj, rpart, val)


def _deep_getattr(obj, key):
for part in key.split("."):
obj = getattr(obj, part)
return obj
from .initialization import InitMessenger, init_to_feasible, init_to_median
from .utils import _product, deep_getattr, deep_setattr, helpful_support_errors


def prototype_hide_fn(msg):
Expand Down Expand Up @@ -392,7 +362,7 @@ def _setup_prototype(self, *args, **kwargs):

value = PyroParam(value, site["fn"].support, event_dim)
with helpful_support_errors(site):
_deep_setattr(self, name, value)
deep_setattr(self, name, value)

def forward(self, *args, **kwargs):
"""
Expand Down Expand Up @@ -503,18 +473,18 @@ def _setup_prototype(self, *args, **kwargs):
init_loc = periodic_repeat(init_loc, full_size, dim).contiguous()
init_scale = torch.full_like(init_loc, self._init_scale)

_deep_setattr(
deep_setattr(
self.locs, name, PyroParam(init_loc, constraints.real, event_dim)
)
_deep_setattr(
deep_setattr(
self.scales,
name,
PyroParam(init_scale, self.scale_constraint, event_dim),
)

def _get_loc_and_scale(self, name):
site_loc = _deep_getattr(self.locs, name)
site_scale = _deep_getattr(self.scales, name)
site_loc = deep_getattr(self.locs, name)
site_scale = deep_getattr(self.scales, name)
return site_loc, site_scale

def forward(self, *args, **kwargs):
Expand Down Expand Up @@ -1250,7 +1220,7 @@ def _setup_prototype(self, *args, **kwargs):
for site, Dist, param_spec in self._discrete_sites:
name = site["name"]
for param_name, param_init, param_constraint in param_spec:
_deep_setattr(
deep_setattr(
self,
"{}_{}".format(name, param_name),
PyroParam(param_init, constraint=param_constraint),
Expand Down Expand Up @@ -1464,23 +1434,23 @@ def _setup_prototype(self, *args, **kwargs):
for name, site in sample_sites.items():
# Initialize location parameters.
init_loc = init_locs[name]
_deep_setattr(self.locs, name, PyroParam(init_loc))
deep_setattr(self.locs, name, PyroParam(init_loc))

# Initialize parameters of conditional distributions.
conditional = self.conditionals[name]
if callable(conditional):
_deep_setattr(self.conds, name, conditional)
deep_setattr(self.conds, name, conditional)
else:
if conditional not in ("delta", "normal", "mvn"):
raise ValueError(f"Unsupported conditional type: {conditional}")
if conditional in ("normal", "mvn"):
init_scale = torch.full_like(init_loc, self._init_scale)
_deep_setattr(
deep_setattr(
self.scales, name, PyroParam(init_scale, self.scale_constraint)
)
if conditional == "mvn":
init_scale_tril = eye_like(init_loc, init_loc.numel())
_deep_setattr(
deep_setattr(
self.scale_trils,
name,
PyroParam(init_scale_tril, self.scale_tril_constraint),
Expand All @@ -1489,7 +1459,7 @@ def _setup_prototype(self, *args, **kwargs):
# Initialize dependencies on upstream variables.
num_pending[name] = 0
deps = PyroModule()
_deep_setattr(self.deps, name, deps)
deep_setattr(self.deps, name, deps)
for upstream, dep in self.dependencies.get(name, {}).items():
assert upstream in sample_sites
children[upstream].append(name)
Expand All @@ -1501,7 +1471,7 @@ def _setup_prototype(self, *args, **kwargs):
raise ValueError(
f"Expected either the string 'linear' or a callable, but got {dep}"
)
_deep_setattr(deps, upstream, dep)
deep_setattr(deps, upstream, dep)

# Topologically sort sites.
# TODO should we choose a more optimal structure?
Expand Down Expand Up @@ -1543,11 +1513,11 @@ def get_deltas(self, save_params=None):

# Sample zero-mean blockwise independent Delta/Normal/MVN.
log_density = 0.0
loc = _deep_getattr(self.locs, name)
loc = deep_getattr(self.locs, name)
zero = torch.zeros_like(loc)
conditional = self.conditionals[name]
if callable(conditional):
aux_value = _deep_getattr(self.conds, name)()
aux_value = deep_getattr(self.conds, name)()
elif conditional == "delta":
aux_value = zero
elif conditional == "normal":
Expand All @@ -1556,7 +1526,7 @@ def get_deltas(self, save_params=None):
dist.Normal(zero, 1).to_event(1),
infer={"is_auxiliary": True},
)
scale = _deep_getattr(self.scales, name)
scale = deep_getattr(self.scales, name)
aux_value = aux_value * scale
if compute_density:
log_density = (-scale.log()).expand_as(aux_value)
Expand All @@ -1568,8 +1538,8 @@ def get_deltas(self, save_params=None):
dist.Normal(zero, 1).to_event(1),
infer={"is_auxiliary": True},
)
scale = _deep_getattr(self.scales, name)
scale_tril = _deep_getattr(self.scale_trils, name)
scale = deep_getattr(self.scales, name)
scale_tril = deep_getattr(self.scale_trils, name)
aux_value = aux_value @ scale_tril.T * scale
if compute_density:
log_density = (
Expand All @@ -1587,9 +1557,9 @@ def get_deltas(self, save_params=None):
# Note: these shear transforms have no effect on the Jacobian
# determinant, and can therefore be excluded from the log_density
# computation below, even for nonlinear dep().
deps = _deep_getattr(self.deps, name)
deps = deep_getattr(self.deps, name)
for upstream in self.dependencies.get(name, {}):
dep = _deep_getattr(deps, upstream)
dep = deep_getattr(deps, upstream)
aux_value = aux_value + dep(aux_values[upstream])
aux_values[name] = aux_value

Expand Down Expand Up @@ -1637,7 +1607,7 @@ def forward(self, *args, **kwargs):
def median(self, *args, **kwargs):
result = {}
for name, site in self._sorted_sites:
loc = _deep_getattr(self.locs, name).detach()
loc = deep_getattr(self.locs, name).detach()
shape = self._batch_shapes[name] + self._unconstrained_event_shapes[name]
loc = loc.reshape(shape)
result[name] = biject_to(site["fn"].support)(loc)
Expand Down
28 changes: 28 additions & 0 deletions pyro/infer/autoguide/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import functools
from contextlib import contextmanager

from pyro import poutine
from pyro.nn import PyroModule


def _product(shape):
Expand All @@ -16,6 +18,32 @@ def _product(shape):
return result


def deep_setattr(obj, key, val):
    """
    Set an attribute `key` on the object, where `key` may be a dotted path.
    If any of the prefix attributes do not exist, they are created as
    :class:`~pyro.nn.PyroModule` instances.
    """
    *prefix, leaf = key.split(".")
    target = obj
    for attr in prefix:
        node = getattr(target, attr, None)
        if node is None:
            # Create the missing container, then read it back: PyroModule
            # attribute assignment may wrap the value being stored.
            setattr(target, attr, PyroModule())
            node = getattr(target, attr)
        target = node
    setattr(target, leaf, val)


def deep_getattr(obj, key):
    """Fetch an attribute `key` from the object, where `key` may be a dotted path."""
    return functools.reduce(getattr, key.split("."), obj)


def mean_field_entropy(model, args, whitelist=None):
"""Computes the entropy of a model, assuming
that the model is fully mean-field (i.e. all sample sites
Expand Down
Loading