Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Precondition interpretation for Gaussian TVE #553

Merged
merged 82 commits into from
Nov 12, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
82 commits
Select commit Hold shift + click to select a range
9fc9593
Add Precondition interpretation for Gaussian TVE
fritzo Sep 24, 2021
f674d2b
Fix bugs; add a test
fritzo Sep 24, 2021
6b2db4d
Add an ops.getslice for more complex eager indexing
fritzo Sep 25, 2021
6f7b326
Fix bugs, add patterns, add more tests
fritzo Sep 25, 2021
33d382b
fix is_affine()
fritzo Sep 25, 2021
aa24595
Merge branch 'getslice-op' into precondition
fritzo Sep 25, 2021
2bf875f
Add more tests
fritzo Sep 25, 2021
0e14835
Merge branch 'master' into getslice-op
fritzo Sep 26, 2021
697c6cd
Fix eager_getslice_lambda
fritzo Sep 27, 2021
d108aab
Merge branch 'getslice-op' into precondition
fritzo Sep 27, 2021
56442c2
Merge branch 'master' into precondition
fritzo Sep 27, 2021
b500034
Merge branch 'master' into precondition
fritzo Oct 4, 2021
bf30d15
Switch to sqrt(precision) representation in Gaussian
fritzo Oct 7, 2021
6ad1952
Fix some bugs
fritzo Oct 7, 2021
15f767c
Fix more math
fritzo Oct 7, 2021
5b3c285
Add GaussianMeta conversions; fix broadcasting bug
fritzo Oct 7, 2021
6317f7f
Fix some distribution tests
fritzo Oct 7, 2021
841010a
Refactor from info_vec to white_vec
fritzo Oct 8, 2021
57a1204
Fix more tests
fritzo Oct 8, 2021
d858cd5
Flesh our matrix_and_mvn_to_funsor()
fritzo Oct 8, 2021
47afb49
Work our marginalization
fritzo Oct 10, 2021
e919c33
fix more tests
fritzo Oct 10, 2021
965bb50
Fix more tests
fritzo Oct 10, 2021
c1c8d18
Fix test_gaussian.py
fritzo Oct 11, 2021
47ab8da
Fix distribution patterns
fritzo Oct 11, 2021
fe0c7c5
Fix argmax approximation
fritzo Oct 11, 2021
10b3432
Remove Gaussian.negate attribute
fritzo Oct 11, 2021
702152b
Fix matrix_and_mvn_to_funsor diag (full still broken)
fritzo Oct 12, 2021
493edb6
Fix old uses of info_vec
fritzo Oct 12, 2021
67ad0c1
Add a test
fritzo Oct 12, 2021
2d4fdb9
Fix shape bug in matrix_and_mvn_to_funsor()
fritzo Oct 12, 2021
18674e8
Merge branch 'master' into srif
fritzo Oct 12, 2021
eeda90d
Enable pprint for funsors
fritzo Oct 12, 2021
5f17da8
Revert pp property
fritzo Oct 12, 2021
be11455
Merge branch 'pprint' into srif
fritzo Oct 12, 2021
d7dfd20
Fix matrix_and_mvn_to_funsor()
fritzo Oct 12, 2021
f99682a
Relax rank condition
fritzo Oct 12, 2021
b5bee71
Merge branch 'master' into srif
fritzo Oct 12, 2021
cc1e08c
Fix ._sample()
fritzo Oct 12, 2021
435119a
Fix eager_contraction_to_binary
fritzo Oct 12, 2021
c225b59
Fix test_joint.py
fritzo Oct 12, 2021
f279dd3
Fix comparisons in sequential sum product
fritzo Oct 13, 2021
2efa851
Fix saarka bilmes test
fritzo Oct 13, 2021
8c301dd
Add and xfail tests of singular matrices
fritzo Oct 13, 2021
25e8c87
Fix rank deficiency issues
fritzo Oct 13, 2021
60cc8e5
Add gaussian integrate patterns
fritzo Oct 13, 2021
631e06c
Fix comment
fritzo Oct 13, 2021
503ffd7
Add a set_compression_threshold context manager
fritzo Oct 13, 2021
22479dc
Update docstring
fritzo Oct 13, 2021
8aa123d
Merge branch 'master' into srif
fritzo Oct 13, 2021
639ed0b
Fix backward sampling support bug
fritzo Oct 13, 2021
76d8bcd
Xfail test_elbo.py::test_complex
fritzo Oct 13, 2021
c709453
Relax test thresholds
fritzo Oct 13, 2021
c8ff3a9
Fix ops.qr numpy backend
fritzo Oct 13, 2021
503383b
Fix jax tests
fritzo Oct 13, 2021
ec499b0
Fix bugs
fritzo Oct 13, 2021
f5d8519
Tweak sensor example
fritzo Oct 13, 2021
ecc249b
Merge branch 'master' into precondition
fritzo Oct 14, 2021
3f7af74
Merge branch 'srif' into precondition
fritzo Oct 14, 2021
f39019b
Fix bugs
fritzo Oct 14, 2021
577189b
Add more precondition approximate patterns
fritzo Oct 15, 2021
5f468aa
Address review comments
fritzo Oct 16, 2021
4d1af1a
Merge branch 'srif' into precondition
fritzo Oct 16, 2021
bcce722
Add Sub[Gaussian, tuple] pattern
fritzo Oct 17, 2021
8eb9d8a
Sketch implementation of partial sampling from Gaussians
fritzo Oct 18, 2021
3339b30
Fix bug
fritzo Oct 18, 2021
3b544cb
Fix a bug in partial sampling
fritzo Oct 18, 2021
d8a9919
Get partial sampling working
fritzo Oct 18, 2021
bb71c2c
Merge branch 'master' into precondition
fritzo Oct 18, 2021
673b64a
Reorder Gaussians in cnf
fritzo Oct 18, 2021
afbe277
Fix batch shape computation
fritzo Oct 19, 2021
7691e9c
Add pattern to fuse nested Subs
fritzo Oct 19, 2021
c1dcbc1
Merge branch 'fuse-subs' into precondition
fritzo Oct 19, 2021
89ef6d5
Relax tolerance
fritzo Oct 19, 2021
ca5b49b
Fix eager_finitary_cat
fritzo Oct 19, 2021
893e8ff
Merge branch 'master' into precondition
fritzo Oct 19, 2021
0e47b85
Increase sample count
fritzo Oct 20, 2021
8dce21e
Fix jax backend for ops.randn
fritzo Oct 20, 2021
6a303f6
Revert Gaussian - Gaussian pattern
fritzo Oct 27, 2021
ba2b740
Relax tolerance
fritzo Oct 27, 2021
fc13611
Merge branch 'master' into precondition
fritzo Nov 12, 2021
ca6cae7
Remove obsolete test
fritzo Nov 12, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/interpretations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ Monte Carlo
:show-inheritance:
:member-order: bysource

Preconditioning
---------------
.. automodule:: funsor.precondition
:members:
:show-inheritance:
:member-order: bysource

Approximations
--------------
.. automodule:: funsor.approximations
Expand Down
2 changes: 2 additions & 0 deletions funsor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
joint,
montecarlo,
ops,
precondition,
recipes,
sum_product,
terms,
Expand Down Expand Up @@ -98,6 +99,7 @@
"montecarlo",
"of_shape",
"ops",
"precondition",
"pretty",
"quote",
"reals",
Expand Down
7 changes: 2 additions & 5 deletions funsor/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,7 @@ def __enter__(self):
self._old_interpretation = interpreter.get_interpretation()
return super().__enter__()

def adjoint(self, sum_op, bin_op, root, targets=None, *, batch_vars=frozenset()):
# TODO Replace this with root + Constant(...) after #548 merges.
root_vars = root.input_vars | batch_vars

def adjoint(self, sum_op, bin_op, root, targets=None, *, batch_vars=set()):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note that batch_vars can now change during the course of adjoint, e.g. in Precondition where the aux vars aren't know until each Approximate term is hit.

zero = to_funsor(ops.UNITS[sum_op])
one = to_funsor(ops.UNITS[bin_op])
adjoint_values = defaultdict(lambda: zero)
Expand Down Expand Up @@ -118,7 +115,7 @@ def adjoint(self, sum_op, bin_op, root, targets=None, *, batch_vars=frozenset())
in_adjs = adjoint_ops(fn, sum_op, bin_op, adjoint_values[output], *inputs)
for v, adjv in in_adjs:
# Marginalize out message variables that don't appear in recipients.
agg_vars = adjv.input_vars - v.input_vars - root_vars
agg_vars = adjv.input_vars - v.input_vars - root.input_vars - batch_vars
assert "particle" not in {var.name for var in agg_vars} # DEBUG FIXME
old_value = adjoint_values[v]
adjoint_values[v] = sum_op(old_value, adjv.reduce(sum_op, agg_vars))
Expand Down
4 changes: 3 additions & 1 deletion funsor/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def _(fn):

@affine_inputs.register(Unary)
def _(fn):
if fn.op in (ops.neg, ops.sum) or isinstance(fn.op, ops.ReshapeOp):
if fn.op in (ops.neg, ops.sum) or isinstance(
fn.op, (ops.ReshapeOp, ops.GetsliceOp)
):
return affine_inputs(fn.arg)
return frozenset()

Expand Down
53 changes: 53 additions & 0 deletions funsor/domains.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from weakref import WeakValueDictionary

import funsor.ops as ops
from funsor.ops.builtin import parse_ellipsis, parse_slice
from funsor.util import broadcast_shape, get_backend, get_tracing_state, quote

Domain = type
Expand Down Expand Up @@ -331,6 +332,58 @@ def _find_domain_getitem(op, lhs_domain, rhs_domain):
)


@find_domain.register(ops.GetsliceOp)
def _find_domain_getslice(op, domain):
index = op.defaults["index"]
if isinstance(domain, ArrayType):
dtype = domain.dtype
shape = list(domain.shape)
left, right = parse_ellipsis(index)

i = 0
for part in left:
if part is None:
shape.insert(i, 1)
i += 1
elif isinstance(part, int):
del shape[i]
elif isinstance(part, slice):
start, stop, step = parse_slice(part, shape[i])
shape[i] = max(0, (stop - start + step - 1) // step)
i += 1
else:
raise ValueError(part)

i = -1
for part in reversed(right):
if part is None:
shape.insert(len(shape) + i + 1, 1)
i -= 1
elif isinstance(part, int):
del shape[i]
elif isinstance(part, slice):
start, stop, step = parse_slice(part, shape[i])
shape[i] = max(0, (stop - start + step - 1) // step)
i -= 1
else:
raise ValueError(part)

return Array[dtype, tuple(shape)]

if isinstance(domain, ProductDomain):
if isinstance(index, tuple):
assert len(index) == 1
index = index[0]
if isinstance(index, int):
return domain.__args__[index]
elif isinstance(index, slice):
return Product[domain.__args__[index]]
else:
raise ValueError(index)

raise NotImplementedError("TODO")


@find_domain.register(ops.BinaryOp)
def _find_domain_pointwise_binary_generic(op, lhs, rhs):
if (
Expand Down
51 changes: 36 additions & 15 deletions funsor/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,8 @@ def _eager_subs_real(self, subs, remaining_subs):
assert value.shape[-1] == self.inputs[k].num_elements
values[k] = ops.expand(value, batch_shape + value.shape[-1:])

# Try to perform a complete substitution of all real variables, resulting in a Tensor.
# Try to perform a complete substitution of all real variables,
# resulting in a Tensor.
if all(k in subs for k, d in self.inputs.items() if d.dtype == "real"):
# Form the concatenated value.
value = BlockVector(batch_shape + (event_size,))
Expand All @@ -461,8 +462,9 @@ def _eager_subs_real(self, subs, remaining_subs):
assert result.output == Real
return Subs(result, remaining_subs) if remaining_subs else result

# Perform a partial substution of a subset of real variables, resulting in a Joint.
# We split real inputs into two sets: a for the preserved and b for the substituted.
# Perform a partial substution of a subset of real variables, resulting
# in a Joint. We split real inputs into two sets: a for the preserved
# and b for the substituted.
b = frozenset(k for k, v in subs.items())
a = frozenset(
k for k, d in self.inputs.items() if d.dtype == "real" and k not in b
Expand Down Expand Up @@ -703,7 +705,9 @@ def _sample(self, sampled_vars, sample_inputs, rng_key):
sample_inputs = OrderedDict(
(k, d) for k, d in sample_inputs.items() if k not in self.inputs
)
sample_shape = tuple(int(d.dtype) for d in sample_inputs.values())
sample_shape = tuple(
d.size for d in sample_inputs.values() if d.dtype != "real"
)
int_inputs = OrderedDict(
(k, d) for k, d in self.inputs.items() if d.dtype != "real"
)
Expand All @@ -716,32 +720,49 @@ def _sample(self, sampled_vars, sample_inputs, rng_key):
if sampled_vars == frozenset(real_inputs):
shape = sample_shape + self.info_vec.shape
backend = get_backend()
if backend != "numpy":
info_vec = self.info_vec
precision_chol = self._precision_chol
if (
len(sample_inputs) == 1
and next(iter(sample_inputs.values())).dtype == "real"
):
# Lazily compute a sample as a function of white noise.
for k, d in sample_inputs.items():
white_noise = Variable(k, d)[tuple(int_inputs)]
info_vec = Tensor(info_vec, int_inputs)
precision_chol = Tensor(precision_chol, int_inputs)
elif backend == "numpy":
# Eagerly draw noise.
white_noise = np.random.randn(*shape)
else:
# Eagerly draw noise.
from importlib import import_module

dist = import_module(
funsor.distribution.BACKEND_TO_DISTRIBUTIONS_BACKEND[backend]
)
sample_args = (shape,) if rng_key is None else (rng_key, shape)
white_noise = dist.Normal.dist_class(0, 1).sample(*sample_args)
else:
white_noise = np.random.randn(*shape)
white_noise = ops.unsqueeze(white_noise, -1)

white_vec = ops.triangular_solve(
self.info_vec[..., None], self._precision_chol
)
# Jointly sample.
# This section may involve either Funsors or backend arrays.
white_vec = ops.triangular_solve(info_vec[..., None], precision_chol)
sample = ops.triangular_solve(
white_noise + white_vec, self._precision_chol, transpose=True
white_noise[..., None] + white_vec, precision_chol, transpose=True
)[..., 0]

# Extract shaped components.
offsets, _ = _compute_offsets(real_inputs)
results = []
for key, domain in real_inputs.items():
data = sample[..., offsets[key] : offsets[key] + domain.num_elements]
data = data.reshape(shape[:-1] + domain.shape)
point = Tensor(data, inputs)
# TODO Support nontrivial slices in Funsor.__getitem__().
point = sample[..., offsets[key] : offsets[key] + domain.num_elements]
point = point.reshape(point.shape[:-1] + domain.shape)
if not isinstance(point, Funsor): # I.e. when eagerly sampling.
point = Tensor(point, inputs)
assert point.output == domain
results.append(Delta(key, point))

results.append(self.log_normalizer)
return reduce(ops.add, results)

Expand Down
5 changes: 4 additions & 1 deletion funsor/montecarlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

from funsor.cnf import Contraction
from funsor.delta import Delta
from funsor.gaussian import Gaussian
from funsor.integrate import Integrate
from funsor.interpretations import StatefulInterpretation
from funsor.tensor import Tensor
from funsor.terms import Approximate, Funsor, Number
from funsor.terms import Approximate, Funsor, Number, Subs
from funsor.util import get_backend

from . import ops
Expand Down Expand Up @@ -86,8 +87,10 @@ def _extract_samples_contraction(discrete_density):
return result


@extract_samples.register(Subs)
@extract_samples.register(Number)
@extract_samples.register(Tensor)
@extract_samples.register(Gaussian)
def _extract_samples_scale(discrete_density):
return {}

Expand Down
101 changes: 101 additions & 0 deletions funsor/ops/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
UNITS,
BinaryOp,
Op,
OpMeta,
TransformOp,
UnaryOp,
declare_op_types,
Expand Down Expand Up @@ -43,6 +44,105 @@ def getitem(lhs, rhs, offset=0):
return lhs[(slice(None),) * offset + (rhs,)]


class GetsliceMeta(OpMeta):
"""
Works around slice objects not being hashable.
"""

def hash_args_kwargs(cls, args, kwargs):
index = args[0] if args else kwargs["index"]
if not isinstance(index, tuple):
index = (index,)
key = tuple(
(x.start, x.stop, x.step) if isinstance(x, slice) else x for x in index
)
return key


@UnaryOp.make(metaclass=GetsliceMeta)
def getslice(x, index=Ellipsis):
return x[index]


getslice.supported_types = (type(None), type(Ellipsis), int, slice)


def parse_ellipsis(index):
"""
Helper to split a slice into parts left and right of Ellipses.

:param index: A tuple, or other object (None, int, slice, Funsor).
:returns: a pair of tuples ``left, right``.
:rtype: tuple
"""
if not isinstance(index, tuple):
index = (index,)
left = []
i = 0
for part in index:
i += 1
if part is Ellipsis:
break
left.append(part)
right = []
for part in reversed(index[i:]):
if part is Ellipsis:
break
right.append(part)
right.reverse()
return tuple(left), tuple(right)


def normalize_ellipsis(index, size):
"""
Expand Ellipses in an index to fill the given number of dimensions.

This should satisfy the equation::

x[i] == x[normalize_ellipsis(i, len(x.shape))]
"""
left, right = parse_ellipsis(index)
if len(left) + len(right) > size:
raise ValueError(f"Index is too wide: {index}")
middle = (slice(None),) * (size - len(left) - len(right))
return left + middle + right


def parse_slice(s, size):
"""
Helper to determine nonnegative integers (start, stop, step) of a slice.

:param slice s: A slice.
:param int size: The size of the array being indexed into.
:returns: A tuple of nonnegative integers ``start, stop, step``.
:rtype: tuple
"""
start = s.start
if start is None:
start = 0
assert isinstance(start, int)
if start >= 0:
start = min(size, start)
else:
start = max(0, size + start)

stop = s.stop
if stop is None:
stop = size
assert isinstance(stop, int)
if stop >= 0:
stop = min(size, stop)
else:
stop = max(0, size + stop)

step = s.step
if step is None:
step = 1
assert isinstance(step, int)

return start, stop, step


abs = UnaryOp.make(_builtin_abs)
eq = BinaryOp.make(operator.eq)
ge = BinaryOp.make(operator.ge)
Expand Down Expand Up @@ -194,6 +294,7 @@ def sigmoid_log_abs_det_jacobian(x, y):
"floordiv",
"ge",
"getitem",
"getslice",
"gt",
"invert",
"le",
Expand Down
Loading