Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to sqrt(precision) representation in Gaussian #568

Merged
merged 46 commits into from
Oct 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
bf30d15
Switch to sqrt(precision) representation in Gaussian
fritzo Oct 7, 2021
6ad1952
Fix some bugs
fritzo Oct 7, 2021
15f767c
Fix more math
fritzo Oct 7, 2021
5b3c285
Add GaussianMeta conversions; fix broadcasting bug
fritzo Oct 7, 2021
6317f7f
Fix some distribution tests
fritzo Oct 7, 2021
841010a
Refactor from info_vec to white_vec
fritzo Oct 8, 2021
57a1204
Fix more tests
fritzo Oct 8, 2021
d858cd5
Flesh our matrix_and_mvn_to_funsor()
fritzo Oct 8, 2021
47afb49
Work our marginalization
fritzo Oct 10, 2021
e919c33
fix more tests
fritzo Oct 10, 2021
965bb50
Fix more tests
fritzo Oct 10, 2021
c1c8d18
Fix test_gaussian.py
fritzo Oct 11, 2021
47ab8da
Fix distribution patterns
fritzo Oct 11, 2021
fe0c7c5
Fix argmax approximation
fritzo Oct 11, 2021
10b3432
Remove Gaussian.negate attribute
fritzo Oct 11, 2021
702152b
Fix matrix_and_mvn_to_funsor diag (full still broken)
fritzo Oct 12, 2021
493edb6
Fix old uses of info_vec
fritzo Oct 12, 2021
67ad0c1
Add a test
fritzo Oct 12, 2021
2d4fdb9
Fix shape bug in matrix_and_mvn_to_funsor()
fritzo Oct 12, 2021
18674e8
Merge branch 'master' into srif
fritzo Oct 12, 2021
eeda90d
Enable pprint for funsors
fritzo Oct 12, 2021
5f17da8
Revert pp property
fritzo Oct 12, 2021
be11455
Merge branch 'pprint' into srif
fritzo Oct 12, 2021
d7dfd20
Fix matrix_and_mvn_to_funsor()
fritzo Oct 12, 2021
f99682a
Relax rank condition
fritzo Oct 12, 2021
b5bee71
Merge branch 'master' into srif
fritzo Oct 12, 2021
cc1e08c
Fix ._sample()
fritzo Oct 12, 2021
435119a
Fix eager_contraction_to_binary
fritzo Oct 12, 2021
c225b59
Fix test_joint.py
fritzo Oct 12, 2021
f279dd3
Fix comparisons in sequential sum product
fritzo Oct 13, 2021
2efa851
Fix saarka bilmes test
fritzo Oct 13, 2021
8c301dd
Add and xfail tests of singular matrices
fritzo Oct 13, 2021
25e8c87
Fix rank deficiency issues
fritzo Oct 13, 2021
60cc8e5
Add gaussian integrate patterns
fritzo Oct 13, 2021
631e06c
Fix comment
fritzo Oct 13, 2021
503ffd7
Add a set_compression_threshold context manager
fritzo Oct 13, 2021
22479dc
Update docstring
fritzo Oct 13, 2021
8aa123d
Merge branch 'master' into srif
fritzo Oct 13, 2021
639ed0b
Fix backward sampling support bug
fritzo Oct 13, 2021
76d8bcd
Xfail test_elbo.py::test_complex
fritzo Oct 13, 2021
c709453
Relax test thresholds
fritzo Oct 13, 2021
c8ff3a9
Fix ops.qr numpy backend
fritzo Oct 13, 2021
503383b
Fix jax tests
fritzo Oct 13, 2021
ec499b0
Fix bugs
fritzo Oct 13, 2021
f5d8519
Tweak sensor example
fritzo Oct 13, 2021
5f468aa
Address review comments
fritzo Oct 16, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ jobs:
strategy:
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]

env:
CI: 1
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -29,8 +30,7 @@ jobs:
pip install .[test]
pip freeze
- name: Run test
run: |
make test
run: make test


torch:
Expand All @@ -39,7 +39,9 @@ jobs:
strategy:
matrix:
python-version: [3.6]

env:
CI: 1
FUNSOR_BACKEND: torch
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -57,9 +59,7 @@ jobs:
pip install .[test,torch]
pip freeze
- name: Run test
run: |
make test
FUNSOR_BACKEND=torch make test
run: make test


jax:
Expand All @@ -68,7 +68,9 @@ jobs:
strategy:
matrix:
python-version: [3.6]

env:
CI: 1
FUNSOR_BACKEND: jax
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -85,5 +87,4 @@ jobs:
pip install .[test,jax]
pip freeze
- name: Run test
run: |
CI=1 FUNSOR_BACKEND=jax make test
run: make test
9 changes: 6 additions & 3 deletions examples/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def generate_data(num_frames, num_sensors):
]
)
trans_dist = dist.MultivariateNormal(
torch.zeros(4), scale_tril=trans_noise * NCV_PROCESS_NOISE.cholesky()
torch.zeros(4),
scale_tril=trans_noise * torch.linalg.cholesky(NCV_PROCESS_NOISE),
)

# define biased sensors
Expand Down Expand Up @@ -128,7 +129,7 @@ def forward(self, observations, add_bias=True):
curr = Variable("curr", Reals[4])
self.trans_dist = f_dist.MultivariateNormal(
loc=prev @ NCV_TRANSITION_MATRIX,
scale_tril=trans_noise * NCV_PROCESS_NOISE.cholesky(),
scale_tril=trans_noise * torch.linalg.cholesky(NCV_PROCESS_NOISE),
value=curr,
)

Expand Down Expand Up @@ -239,7 +240,9 @@ def main(args):
or not args.metrics_filename
or not os.path.exists(args.metrics_filename)
):
results = track(args)
# Increase compression threshold for numerical stability.
with funsor.gaussian.Gaussian.set_compression_threshold(3):
results = track(args)
else:
results = torch.load(args.metrics_filename)

Expand Down
5 changes: 1 addition & 4 deletions funsor/approximations.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,7 @@ def compute_argmax_gaussian(model, approx_vars):

approx_names = frozenset(v.name for v in approx_vars)
if approx_names == frozenset(real_inputs):
x = model.info_vec[..., None]
x = ops.triangular_solve(x, model._precision_chol)
x = ops.triangular_solve(x, model._precision_chol, transpose=True)
mode = x[..., 0]
mode = model._mean
offsets, _ = _compute_offsets(real_inputs)
result = {}
for key, domain in real_inputs.items():
Expand Down
33 changes: 16 additions & 17 deletions funsor/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,19 +697,12 @@ def indep_to_data(funsor_dist, name_to_dim=None):


@to_data.register(Gaussian)
def gaussian_to_data(funsor_dist, name_to_dim=None, normalized=False):
if normalized:
return to_data(
funsor_dist.log_normalizer + funsor_dist, name_to_dim=name_to_dim
)
loc = ops.cholesky_solve(
ops.unsqueeze(funsor_dist.info_vec, -1), ops.cholesky(funsor_dist.precision)
).squeeze(-1)
def gaussian_to_data(funsor_dist, name_to_dim=None):
int_inputs = OrderedDict(
(k, d) for k, d in funsor_dist.inputs.items() if d.dtype != "real"
)
loc = to_data(Tensor(loc, int_inputs), name_to_dim)
precision = to_data(Tensor(funsor_dist.precision, int_inputs), name_to_dim)
loc = to_data(Tensor(funsor_dist._mean, int_inputs), name_to_dim)
precision = to_data(Tensor(funsor_dist._precision, int_inputs), name_to_dim)
backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
return backend_dist.MultivariateNormal.dist_class(loc, precision_matrix=precision)

Expand Down Expand Up @@ -845,13 +838,17 @@ def eager_normal(loc, scale, value):
if not is_affine(loc) or not is_affine(value):
return None # lazy

info_vec = ops.new_zeros(scale.data, scale.data.shape + (1,))
precision = ops.pow(scale.data, -2).reshape(scale.data.shape + (1, 1))
log_prob = -0.5 * math.log(2 * math.pi) - ops.log(scale).sum()
white_vec = ops.new_zeros(scale.data, scale.data.shape + (1,))
prec_sqrt = (1 / scale.data)[..., None, None]
log_prob = -0.5 * math.log(2 * math.pi) - ops.log(scale)
inputs = scale.inputs.copy()
var = gensym("value")
inputs[var] = Real
gaussian = log_prob + Gaussian(info_vec, precision, inputs)
gaussian = log_prob + Gaussian(
white_vec=white_vec,
prec_sqrt=prec_sqrt,
inputs=inputs,
)
return gaussian(**{var: value - loc})


Expand All @@ -862,16 +859,18 @@ def eager_mvn(loc, scale_tril, value):
if not is_affine(loc) or not is_affine(value):
return None # lazy

info_vec = ops.new_zeros(scale_tril.data, scale_tril.data.shape[:-1])
precision = ops.cholesky_inverse(scale_tril.data)
white_vec = ops.new_zeros(scale_tril.data, scale_tril.data.shape[:-1])
prec_sqrt = ops.transpose(ops.triangular_inv(scale_tril.data), -1, -2)
scale_diag = Tensor(ops.diagonal(scale_tril.data, -1, -2), scale_tril.inputs)
log_prob = (
-0.5 * scale_diag.shape[0] * math.log(2 * math.pi) - ops.log(scale_diag).sum()
)
inputs = scale_tril.inputs.copy()
var = gensym("value")
inputs[var] = Reals[scale_diag.shape[0]]
gaussian = log_prob + Gaussian(info_vec, precision, inputs)
gaussian = log_prob + Gaussian(
white_vec=white_vec, prec_sqrt=prec_sqrt, inputs=inputs
)
return gaussian(**{var: value - loc})


Expand Down
Loading