Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Precondition interpretation for Gaussian TVE #553

Merged
merged 82 commits into from
Nov 12, 2021
Merged
Show file tree
Hide file tree
Changes from 81 commits
Commits
Show all changes
82 commits
Select commit Hold shift + click to select a range
9fc9593
Add Precondition interpretation for Gaussian TVE
fritzo Sep 24, 2021
f674d2b
Fix bugs; add a test
fritzo Sep 24, 2021
6b2db4d
Add an ops.getslice for more complex eager indexing
fritzo Sep 25, 2021
6f7b326
Fix bugs, add patterns, add more tests
fritzo Sep 25, 2021
33d382b
fix is_affine()
fritzo Sep 25, 2021
aa24595
Merge branch 'getslice-op' into precondition
fritzo Sep 25, 2021
2bf875f
Add more tests
fritzo Sep 25, 2021
0e14835
Merge branch 'master' into getslice-op
fritzo Sep 26, 2021
697c6cd
Fix eager_getslice_lambda
fritzo Sep 27, 2021
d108aab
Merge branch 'getslice-op' into precondition
fritzo Sep 27, 2021
56442c2
Merge branch 'master' into precondition
fritzo Sep 27, 2021
b500034
Merge branch 'master' into precondition
fritzo Oct 4, 2021
bf30d15
Switch to sqrt(precision) representation in Gaussian
fritzo Oct 7, 2021
6ad1952
Fix some bugs
fritzo Oct 7, 2021
15f767c
Fix more math
fritzo Oct 7, 2021
5b3c285
Add GaussianMeta conversions; fix broadcasting bug
fritzo Oct 7, 2021
6317f7f
Fix some distribution tests
fritzo Oct 7, 2021
841010a
Refactor from info_vec to white_vec
fritzo Oct 8, 2021
57a1204
Fix more tests
fritzo Oct 8, 2021
d858cd5
Flesh our matrix_and_mvn_to_funsor()
fritzo Oct 8, 2021
47afb49
Work our marginalization
fritzo Oct 10, 2021
e919c33
fix more tests
fritzo Oct 10, 2021
965bb50
Fix more tests
fritzo Oct 10, 2021
c1c8d18
Fix test_gaussian.py
fritzo Oct 11, 2021
47ab8da
Fix distribution patterns
fritzo Oct 11, 2021
fe0c7c5
Fix argmax approximation
fritzo Oct 11, 2021
10b3432
Remove Gaussian.negate attribute
fritzo Oct 11, 2021
702152b
Fix matrix_and_mvn_to_funsor diag (full still broken)
fritzo Oct 12, 2021
493edb6
Fix old uses of info_vec
fritzo Oct 12, 2021
67ad0c1
Add a test
fritzo Oct 12, 2021
2d4fdb9
Fix shape bug in matrix_and_mvn_to_funsor()
fritzo Oct 12, 2021
18674e8
Merge branch 'master' into srif
fritzo Oct 12, 2021
eeda90d
Enable pprint for funsors
fritzo Oct 12, 2021
5f17da8
Revert pp property
fritzo Oct 12, 2021
be11455
Merge branch 'pprint' into srif
fritzo Oct 12, 2021
d7dfd20
Fix matrix_and_mvn_to_funsor()
fritzo Oct 12, 2021
f99682a
Relax rank condition
fritzo Oct 12, 2021
b5bee71
Merge branch 'master' into srif
fritzo Oct 12, 2021
cc1e08c
Fix ._sample()
fritzo Oct 12, 2021
435119a
Fix eager_contraction_to_binary
fritzo Oct 12, 2021
c225b59
Fix test_joint.py
fritzo Oct 12, 2021
f279dd3
Fix comparisons in sequential sum product
fritzo Oct 13, 2021
2efa851
Fix saarka bilmes test
fritzo Oct 13, 2021
8c301dd
Add and xfail tests of singular matrices
fritzo Oct 13, 2021
25e8c87
Fix rank deficiency issues
fritzo Oct 13, 2021
60cc8e5
Add gaussian integrate patterns
fritzo Oct 13, 2021
631e06c
Fix comment
fritzo Oct 13, 2021
503ffd7
Add a set_compression_threshold context manager
fritzo Oct 13, 2021
22479dc
Update docstring
fritzo Oct 13, 2021
8aa123d
Merge branch 'master' into srif
fritzo Oct 13, 2021
639ed0b
Fix backward sampling support bug
fritzo Oct 13, 2021
76d8bcd
Xfail test_elbo.py::test_complex
fritzo Oct 13, 2021
c709453
Relax test thresholds
fritzo Oct 13, 2021
c8ff3a9
Fix ops.qr numpy backend
fritzo Oct 13, 2021
503383b
Fix jax tests
fritzo Oct 13, 2021
ec499b0
Fix bugs
fritzo Oct 13, 2021
f5d8519
Tweak sensor example
fritzo Oct 13, 2021
ecc249b
Merge branch 'master' into precondition
fritzo Oct 14, 2021
3f7af74
Merge branch 'srif' into precondition
fritzo Oct 14, 2021
f39019b
Fix bugs
fritzo Oct 14, 2021
577189b
Add more precondition approximate patterns
fritzo Oct 15, 2021
5f468aa
Address review comments
fritzo Oct 16, 2021
4d1af1a
Merge branch 'srif' into precondition
fritzo Oct 16, 2021
bcce722
Add Sub[Gaussian, tuple] pattern
fritzo Oct 17, 2021
8eb9d8a
Sketch implementation of partial sampling from Gaussians
fritzo Oct 18, 2021
3339b30
Fix bug
fritzo Oct 18, 2021
3b544cb
Fix a bug in partial sampling
fritzo Oct 18, 2021
d8a9919
Get partial sampling working
fritzo Oct 18, 2021
bb71c2c
Merge branch 'master' into precondition
fritzo Oct 18, 2021
673b64a
Reorder Gaussians in cnf
fritzo Oct 18, 2021
afbe277
Fix batch shape computation
fritzo Oct 19, 2021
7691e9c
Add pattern to fuse nested Subs
fritzo Oct 19, 2021
c1dcbc1
Merge branch 'fuse-subs' into precondition
fritzo Oct 19, 2021
89ef6d5
Relax tolerance
fritzo Oct 19, 2021
ca5b49b
Fix eager_finitary_cat
fritzo Oct 19, 2021
893e8ff
Merge branch 'master' into precondition
fritzo Oct 19, 2021
0e47b85
Increase sample count
fritzo Oct 20, 2021
8dce21e
Fix jax backend for ops.randn
fritzo Oct 20, 2021
6a303f6
Revert Gaussian - Gaussian pattern
fritzo Oct 27, 2021
ba2b740
Relax tolerance
fritzo Oct 27, 2021
fc13611
Merge branch 'master' into precondition
fritzo Nov 12, 2021
ca6cae7
Remove obsolete test
fritzo Nov 12, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/interpretations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ Monte Carlo
:show-inheritance:
:member-order: bysource

Preconditioning
---------------
.. automodule:: funsor.precondition
:members:
:show-inheritance:
:member-order: bysource

Approximations
--------------
.. automodule:: funsor.approximations
Expand Down
2 changes: 2 additions & 0 deletions funsor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
joint,
montecarlo,
ops,
precondition,
recipes,
sum_product,
terms,
Expand Down Expand Up @@ -102,6 +103,7 @@
"montecarlo",
"of_shape",
"ops",
"precondition",
"pretty",
"quote",
"reals",
Expand Down
7 changes: 2 additions & 5 deletions funsor/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,7 @@ def __enter__(self):
self._old_interpretation = interpreter.get_interpretation()
return super().__enter__()

def adjoint(self, sum_op, bin_op, root, targets=None, *, batch_vars=frozenset()):
# TODO Replace this with root + Constant(...) after #548 merges.
root_vars = root.input_vars | batch_vars

def adjoint(self, sum_op, bin_op, root, targets=None, *, batch_vars=set()):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note that batch_vars can now change during the course of adjoint, e.g. in Precondition where the aux vars aren't know until each Approximate term is hit.

zero = to_funsor(ops.UNITS[sum_op])
one = to_funsor(ops.UNITS[bin_op])
adjoint_values = defaultdict(lambda: zero)
Expand Down Expand Up @@ -118,7 +115,7 @@ def adjoint(self, sum_op, bin_op, root, targets=None, *, batch_vars=frozenset())
in_adjs = adjoint_ops(fn, sum_op, bin_op, adjoint_values[output], *inputs)
for v, adjv in in_adjs:
# Marginalize out message variables that don't appear in recipients.
agg_vars = adjv.input_vars - v.input_vars - root_vars
agg_vars = adjv.input_vars - v.input_vars - root.input_vars - batch_vars
assert "particle" not in {var.name for var in agg_vars} # DEBUG FIXME
old_value = adjoint_values[v]
adjoint_values[v] = sum_op(old_value, adjv.reduce(sum_op, agg_vars))
Expand Down
21 changes: 15 additions & 6 deletions funsor/cnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def _sample(self, sampled_vars, sample_inputs, rng_key):
sampled_vars = sampled_vars.intersection(self.inputs)
if not sampled_vars:
return self
for term in self.terms:
if isinstance(term, Delta):
sampled_vars -= term.fresh
if not sampled_vars:
return self

if self.red_op in (ops.null, ops.logaddexp):
if rng_key is not None and get_backend() == "jax":
Expand All @@ -116,8 +121,8 @@ def _sample(self, sampled_vars, sample_inputs, rng_key):
rng_keys = [None] * len(self.terms)

if self.bin_op in (ops.null, ops.logaddexp):
# Design choice: we sample over logaddexp reductions, but leave logaddexp
# binary choices symbolic.
# Design choice: we sample over logaddexp reductions, but leave
# logaddexp binary choices symbolic.
terms = [
term._sample(
sampled_vars.intersection(term.inputs), sample_inputs, rng_key
Expand All @@ -132,11 +137,15 @@ def _sample(self, sampled_vars, sample_inputs, rng_key):
greedy_vars = sampled_vars.intersection(term.inputs)
if greedy_vars:
break
assert greedy_vars
greedy_terms, terms = [], []
for term in self.terms:
(
terms if greedy_vars.isdisjoint(term.inputs) else greedy_terms
).append(term)
if greedy_vars.isdisjoint(term.inputs):
terms.append(term)
elif isinstance(term, Delta) and greedy_vars.isdisjoint(term.fresh):
terms.append(term)
else:
greedy_terms.append(term)
if len(greedy_terms) == 1:
term = greedy_terms[0]
terms.append(term._sample(greedy_vars, sample_inputs, rng_keys[0]))
Expand Down Expand Up @@ -392,7 +401,7 @@ def _(fn):
# Normalizing Contractions
##########################################

ORDERING = {Delta: 1, Number: 2, Tensor: 3, Gaussian: 4}
ORDERING = {Delta: 1, Number: 2, Tensor: 3, Gaussian: 4, Unary[ops.NegOp, Gaussian]: 5}
GROUND_TERMS = tuple(ORDERING)


Expand Down
16 changes: 16 additions & 0 deletions funsor/domains.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,22 @@ def _find_domain_stack(op, parts):
return output


@find_domain.register(ops.CatOp)
def _find_domain_cat(op, parts):
dim = op.defaults["axis"]
if dim >= 0:
event_dims = {len(x.shape) for x in parts}
assert len(event_dims) == 1, "undefined"
dim = dim - next(iter(event_dims))
assert dim < 0
shape = broadcast_shape(*(x.shape[:dim] for x in parts))
shape += (sum(x.shape[dim] for x in parts),)
if dim < -1:
shape += broadcast_shape(*(x.shape[dim + 1 :] for x in parts))
output = Array[parts[0].dtype, shape]
return output


@find_domain.register(ops.EinsumOp)
def _find_domain_einsum(op, operands):
equation = op.defaults["equation"]
Expand Down
Loading