Skip to content

Commit

Permalink
Implement AutoGaussian guide (#2929)
Browse files Browse the repository at this point in the history
* Implement front-end of AutoGaussian guide

* Implement funsor backend

* Break plates causing collision

* Fix some tests

* Support monte carlo particles

* Tweak variable order of cholesky parametrization

* Register docstring

* Split up methods, split up tests, add new tests, mark stage=funsor

* Fix get_dependencies() to handle pyro.factor

* Add failing test for plated sampling

* Add regression tests for get_dependencies()

* lint

* Use funsor.recipes.forward_filter_backward_rsample

* Fix & strengthen tests

* Reflect

* Add another exact test

* Simplify log_density computation

* Switch from precision to precision_chol parameters

* Add test of Gaussian .rsample() and .log_prob()

* Sketch dense backend

* Update docs

* Perfect precision parametrization (breaking both backends)

* Minor updates

* Change precision representation, start fixing dense backend

* Add more tests

* Refactor to use a class hierarchy

* Flesh out dense backend (some tests fail)

* Fix tests, simplify init logic

* Add a pyro-cov poisson example model

* Flesh out funsor backend

* Add tests, update docs

* Be safer about importing funsor

* Fix more tests

* Add a test

* Simplify; fix bugs

* Tweak test parameters

* Revert unnecessary change

* Mark test funsor stage

* Wrap intractable error in NotImplemented

* Fix serialization tests

* fix typo
  • Loading branch information
fritzo authored Oct 4, 2021
1 parent eadca9c commit e71145a
Show file tree
Hide file tree
Showing 10 changed files with 1,486 additions and 58 deletions.
38 changes: 23 additions & 15 deletions docs/source/infer.autoguide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,71 +8,71 @@ AutoGuide
.. autoclass:: pyro.infer.autoguide.AutoGuide
:members:
:undoc-members:
:special-members: __call__
:member-order: bysource
:show-inheritance:

AutoGuideList
-------------
.. autoclass:: pyro.infer.autoguide.AutoGuideList
:members:
:undoc-members:
:special-members: __call__
:member-order: bysource
:show-inheritance:

AutoCallable
------------
.. autoclass:: pyro.infer.autoguide.AutoCallable
:members:
:undoc-members:
:special-members: __call__
:member-order: bysource
:show-inheritance:

AutoNormal
----------
.. autoclass:: pyro.infer.autoguide.AutoNormal
:members:
:undoc-members:
:special-members: __call__
:member-order: bysource
:show-inheritance:

AutoDelta
---------
.. autoclass:: pyro.infer.autoguide.AutoDelta
:members:
:undoc-members:
:special-members: __call__
:member-order: bysource
:show-inheritance:

AutoContinuous
--------------
.. autoclass:: pyro.infer.autoguide.AutoContinuous
:members:
:undoc-members:
:special-members: __call__
:member-order: bysource
:show-inheritance:

AutoMultivariateNormal
----------------------
.. autoclass:: pyro.infer.autoguide.AutoMultivariateNormal
:members:
:undoc-members:
:special-members: __call__
:member-order: bysource
:show-inheritance:

AutoDiagonalNormal
------------------
.. autoclass:: pyro.infer.autoguide.AutoDiagonalNormal
:members:
:undoc-members:
:special-members: __call__
:member-order: bysource
:show-inheritance:

AutoLowRankMultivariateNormal
-----------------------------
.. autoclass:: pyro.infer.autoguide.AutoLowRankMultivariateNormal
:members:
:undoc-members:
:special-members: __call__
:member-order: bysource
:show-inheritance:


Expand All @@ -81,7 +81,7 @@ AutoNormalizingFlow
.. autoclass:: pyro.infer.autoguide.AutoNormalizingFlow
:members:
:undoc-members:
:special-members: __call__
:member-order: bysource
:show-inheritance:


Expand All @@ -90,31 +90,39 @@ AutoIAFNormal
.. autoclass:: pyro.infer.autoguide.AutoIAFNormal
:members:
:undoc-members:
:special-members: __call__
:member-order: bysource
:show-inheritance:

AutoLaplaceApproximation
-----------------------------
.. autoclass:: pyro.infer.autoguide.AutoLaplaceApproximation
:members:
:undoc-members:
:special-members: __call__
:member-order: bysource
:show-inheritance:

AutoDiscreteParallel
--------------------
.. autoclass:: pyro.infer.autoguide.AutoDiscreteParallel
:members:
:undoc-members:
:special-members: __call__
:member-order: bysource
:show-inheritance:

AutoStructured
--------------------
.. autoclass:: pyro.infer.autoguide.AutoStructured
:members:
:undoc-members:
:special-members: __call__
:member-order: bysource
:show-inheritance:

AutoGaussian
------------
.. autoclass:: pyro.infer.autoguide.AutoGaussian
:members:
:undoc-members:
:member-order: bysource
:show-inheritance:

.. _autoguide-initialization:
Expand All @@ -125,5 +133,5 @@ Initialization
:members:
:undoc-members:
:special-members: __call__
:show-inheritance:
:member-order: bysource
:show-inheritance:
2 changes: 2 additions & 0 deletions pyro/infer/autoguide/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from pyro.infer.autoguide.gaussian import AutoGaussian
from pyro.infer.autoguide.guides import (
AutoCallable,
AutoContinuous,
Expand Down Expand Up @@ -34,6 +35,7 @@
"AutoDelta",
"AutoDiagonalNormal",
"AutoDiscreteParallel",
"AutoGaussian",
"AutoGuide",
"AutoGuideList",
"AutoIAFNormal",
Expand Down
Loading

0 comments on commit e71145a

Please sign in to comment.