
Add pre-conditioning matrix to Barker proposal #731

Merged
26 commits merged on Oct 8, 2024

Conversation

ismael-mendoza
Copy link
Contributor

@ismael-mendoza ismael-mendoza commented Sep 1, 2024

Overview

This PR adds a pre-conditioning matrix to the Barker proposal implementation, as in Appendix G of https://arxiv.org/abs/1908.11812

Discussion / Questions

  • I changed `_compute_acceptance_probability` to operate on flattened arrays, since it needs to matrix-multiply with the pre-conditioning matrix. Is that OK?
  • Do I need to update the PDF? (It is not used anywhere in the codebase.)
  • It's not very clear in the Barker paper, but the step_size (sigma) is a 'global scale' and the inverse mass matrix is separate. See the authors' implementation (in R): https://github.com/gzanella/barker/blob/master/functions.R

Thank you for opening a PR!

A few important guidelines and requirements before we can merge your PR:

  • If I add a new sampler, there is an issue discussing it already;
  • We should be able to understand what the PR does from its title only;
  • There is a high-level description of the changes;
  • There are links to all the relevant issues, discussions and PRs;
  • The branch is rebased on the latest main commit;
  • Commit messages follow these guidelines;
  • The code respects the current naming conventions;
  • Docstrings follow the numpy style guide
  • pre-commit is installed and configured on your machine, and you ran it before opening the PR;
  • There are tests covering the changes;
  • The doc is up-to-date;
  • If I add a new sampler, I added/updated related examples

Consider opening a Draft PR if your work is still in progress but you would like some feedback from other contributors.

This is a first draft of adding the pre-conditioning to the Barker
proposal. This follows Algorithms 4 and 5 in Appendix G of the original
Barker proposal paper. It's somewhat unclear from the paper, but the
separate step size that was already implemented serves as a global
scale for the normal distribution of the proposal. The function
`_compute_acceptance_probability` now takes in the transposed sqrt mass
matrix and its inverse; it has also been flattened to accommodate
the corresponding matrix multiplications.
The original docstring of step_size was incorrect: there is no
symplectic integrator.
We make this possible by adding an identity pre-conditioning matrix,
which should make the test run in the same way as before.
We add a new test to barker.py to ensure that our implementation of
the preconditioning matrix is correct. We follow Appendix G in the
paper that mentions that algorithm 4 and 5 (which we implemented)
should be equivalent to rescaling the parameters and the logdensity
in a specific way. We implement both approaches when using the Barker
proposal to infer the mean and sigma of a normal distribution. We
check that with two different random seeds the output chains are
equivalent up to some tolerance.

We also patch the original test in this file by adding an identity
mass matrix.
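The equivalence used by this test rests on a change-of-variables identity: if x = C y, the rescaled log-density log pi~(y) = log pi(C y) has gradient C^T grad log pi(C y). A small numpy check of that identity (illustrative only, with an arbitrary made-up C; not blackjax code):

```python
import numpy as np

# Change-of-variables check behind the test: with x = C @ y, the pulled-back
# log-density log pi~(y) = log pi(C @ y) has gradient C.T @ grad_log_pi(C @ y).
C = np.array([[1.5, 0.0], [0.3, 0.8]])  # hypothetical sqrt mass matrix

def log_pi(x):        # standard normal target
    return -0.5 * x @ x

def grad_log_pi(x):
    return -x

y = np.array([0.4, -1.2])
x = C @ y

# central finite-difference gradient of the rescaled log-density
eps = 1e-6
g_fd = np.array([
    (log_pi(C @ (y + eps * e)) - log_pi(C @ (y - eps * e))) / (2 * eps)
    for e in np.eye(2)
])
assert np.allclose(g_fd, C.T @ grad_log_pi(x), atol=1e-6)
```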
@ismael-mendoza
Copy link
Contributor Author

ismael-mendoza commented Sep 1, 2024

@AdrienCorenflos I just added a test to specifically check the new implementation with the mass matrix and fixed all the tests that use Barker (by setting the inverse mass matrix to the identity). I'm open to suggestions on the new test or on any other changes. Thank you!

@ismael-mendoza ismael-mendoza marked this pull request as ready for review September 1, 2024 21:18
@AdrienCorenflos AdrienCorenflos self-assigned this Sep 1, 2024
@AdrienCorenflos
Copy link
Contributor

Thanks for this. I think I now know how we can use general Metrics here. This requires a small addition to the Metric interface, whereby a metric would now also have a scale function argument, with signature

# scale: Callable[[ArrayLikeTree, ArrayLikeTree, bool], ArrayLikeTree]
def scale(position: ArrayLikeTree, vector: ArrayLikeTree, inv: bool = False) -> ArrayLikeTree:
    mass_matrix = mass_matrix_fn(position)
    # ravel everything, take mass_matrix_sqrt = cholesky(mass_matrix), etc.
    if inv:
        return triangular_solve(mass_matrix_sqrt, vector)  # plus unravelling, of course
    return mass_matrix_sqrt @ vector  # plus unravelling, of course

@junpenglao would that be ok to add a function like this (essentially scaling a vector by the inv_sqrt_mass defined by the metric at position x).
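For a fixed mass matrix, a minimal numpy sketch of what such a `scale` function could look like (hypothetical helper, not the actual blackjax API; the lower Cholesky factor stands in for the matrix square root):

```python
import numpy as np

def make_scale(mass_matrix_fn):
    """Hypothetical sketch of the proposed `scale`: multiply a vector by the
    lower Cholesky factor of the mass matrix at `position`, or solve against
    it when inv=True (illustrative only, not the actual blackjax API)."""
    def scale(position, vector, inv=False):
        mass_matrix = mass_matrix_fn(position)
        mass_sqrt = np.linalg.cholesky(mass_matrix)  # L with L @ L.T == M
        if inv:
            return np.linalg.solve(mass_sqrt, vector)  # L^{-1} @ vector
        return mass_sqrt @ vector
    return scale

# usage with a fixed (position-independent) mass matrix
M = np.array([[2.0, 0.5], [0.5, 1.0]])
scale = make_scale(lambda position: M)
v = np.array([1.0, -1.0])
assert np.allclose(scale(None, scale(None, v), inv=True), v)  # round trip
```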

If this is ok, then we can implement Barker's proposal with any sort of metric, even Riemannian, for cheap:
https://github.com/ismael-mendoza/blackjax/blob/a61af36ab77dc9926ce2b54cdb3d25f6ecb05dd4/blackjax/mcmc/barker.py#L93-L95 would become

z = metric.scale(y, y, True) - metric.scale(x, x, True)  # Be careful here, it may be metric.scale(x, y, True) - metric.scale(y, x, True), need to check the detailed balanced condition properly and same below.
c_x = metric.scale(x, log_x)
c_y = metric.scale(y, log_y)

And something similar here
https://github.com/ismael-mendoza/blackjax/blob/a61af36ab77dc9926ce2b54cdb3d25f6ecb05dd4/blackjax/mcmc/barker.py#L245-L253

There would be some bookkeeping about shapes and all, but these seem to be the only changes.

@ismael-mendoza @junpenglao what do you think?

@AdrienCorenflos
Copy link
Contributor

This is not even covered in the papers by Livingstone and Zanella: it's a good example of why Blackjax composability is cool :D

tests/mcmc/test_barker.py (outdated review thread, resolved)
blackjax/mcmc/barker.py (outdated review thread, resolved)
blackjax/mcmc/barker.py (outdated review thread, resolved)
@ismael-mendoza
Copy link
Contributor Author

Thanks @AdrienCorenflos that seems like a good idea to me as it avoids the code redundancy and allows us to use Riemannian metrics for free like you pointed out.

On your comment here

z = metric.scale(y, y, True) - metric.scale(x, x, True)  # Be careful here, it may be metric.scale(x, y, True) - metric.scale(y, x, True), need to check the detailed balanced condition properly and same below.

do you have suggestions for how to check the correct detailed balanced condition? I'm not too familiar with this and I assume the Barker paper wouldn't mention it as they only show the algorithm for a fixed mass matrix. Thanks!

@ismael-mendoza
Copy link
Contributor Author

One quick follow up: am I missing something, or would you need two implementations of .scale? One for the gaussian_euclidean and one for the gaussian_riemannian metric? The former won't take the position?

@AdrienCorenflos
Copy link
Contributor

AdrienCorenflos commented Sep 3, 2024

Thanks @AdrienCorenflos that seems like a good idea to me as it avoids the code redundancy and allows us to use Riemannian metrics for free like you pointed out.

On your comment here

z = metric.scale(y, y, True) - metric.scale(x, x, True)  # Be careful here, it may be metric.scale(x, y, True) - metric.scale(y, x, True), need to check the detailed balanced condition properly and same below.

do you have suggestions for how to check the correct detailed balanced condition? I'm not too familiar with this and I assume the Barker paper wouldn't mention it as they only show the algorithm for a fixed mass matrix. Thanks!

Yes, it's no big deal: essentially, the matrix you used to propose the state needs to be one that ends up in the acceptance:

[image: acceptance probability from the paper]

So, here for the proposal we use Ct(x), everywhere,
which means that the acceptance is modified in our case:

[image: acceptance probability with fixed pre-conditioning]

changes to
$$\alpha(x, y) = \min \Big(1, \frac{\pi(y)}{\pi(x)} \Gamma(x, y)\Big)$$
for
$$\Gamma(x, y) = \prod_{i=1}^D \frac{1 + \exp(-u_i c_i(x))}{1 + \exp(-v_i c_i(y))}$$

where $c(x) = \nabla \log \pi(x) C^{T}(x)$, while
$u = C^{T}(x)^{-1}(y - x)$ and $v = C^{T}(y)^{-1}(x - y)$.
In this case, we indeed have
$$\Gamma(x, y) = \frac{q^{B}(x \mid y)}{q^B(y \mid x)},$$
which is sufficient for the detailed balance to be verified.
The simplification of the $\mu_{\sigma}$ here
[image: simplification of $\mu_{\sigma}$ from the paper]
still works because the mass matrix does not intervene there.
EDIT: Cross this, it's probably not true. I'll have to go to pen and paper, I'll come back with the answer tomorrow. But worst case it's a cheap term to add.

So a bit different from what I had written originally but not too far :) Hopefully no typo/forgotten terms!
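For a fixed scale the Γ term can be sanity-checked numerically. A 1D sketch (my own derivation to illustrate, not code from this PR), assuming the proposal density q^B(y|x) = 2 φ(u) / (C (1 + exp(−u c(x)))) with u = (y−x)/C and c(x) = C ∇log π(x):

```python
import math

C = 0.7                         # fixed 1D pre-conditioning scale ("sqrt mass matrix")
grad_log_pi = lambda x: -x      # target: standard normal, so grad log pi(x) = -x

def barker_q(y, x):
    """Assumed 1D pre-conditioned Barker proposal density q^B(y | x)."""
    u = (y - x) / C
    c_x = C * grad_log_pi(x)    # c(x) = C * grad log pi(x)
    phi = math.exp(-0.5 * u * u) / math.sqrt(2 * math.pi)
    return 2.0 * phi / (C * (1.0 + math.exp(-u * c_x)))

x, y = 0.3, -1.1
u, v = (y - x) / C, (x - y) / C
gamma = (1 + math.exp(-u * C * grad_log_pi(x))) / (1 + math.exp(-v * C * grad_log_pi(y)))

# Gamma matches the proposal-density ratio q^B(x | y) / q^B(y | x)
assert abs(gamma - barker_q(x, y) / barker_q(y, x)) < 1e-12
```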

@AdrienCorenflos
Copy link
Contributor

One quick follow up: am I missing something, or would you need two implementations of .scale? One for the gaussian_euclidean and one for the gaussian_riemannian metric? The former won't take the position?

We do need this (or similar) to be added I think, but let's wait for @junpenglao he will likely have a better opinion than me on the right API/design choice for these as he has thought about it a lot more than I have.

@junpenglao
Copy link
Member

One quick follow up: am I missing something, or would you need two implementations of .scale? One for the gaussian_euclidean and one for the gaussian_riemannian metric? The former won't take the position?

We do need this (or similar) to be added I think, but let's wait for @junpenglao he will likely have a better opinion than me on the right API/design choice for these as he has thought about it a lot more than I have.

Yes sounds like a great addition to metric, the API makes sense as well.

@ismael-mendoza
Copy link
Contributor Author

ismael-mendoza commented Sep 18, 2024

Hello @AdrienCorenflos and @junpenglao,

Thanks @AdrienCorenflos for implementing the metric scaling. Now I can proceed with finishing this PR; I just have two remaining questions.

First, I noticed in Section 6.1 of the Barker paper that the authors use the adaptation proposed by Andrieu and Thoms (2008) for their experiments. I think it would be nice to have the full version of their algorithm in blackjax, and I was wondering whether this adaptation has already been implemented somewhere in blackjax, or whether there is already a clear alternative that should be used to adapt HMC-like algorithms. It's unclear to me whether the window_adaptation function commonly used with NUTS would be appropriate here, for example.

Second, @AdrienCorenflos I was just wondering if you finished the derivation above and if there is some additional term that must be added?

Thank you both for your help!

@AdrienCorenflos
Copy link
Contributor

Oh right, I had completely forgotten about this bit. I'm sure I've got the derivation on a piece of paper somewhere 😬
Let me check tomorrow morning.

@AdrienCorenflos
Copy link
Contributor

Regarding the adaptation bit, let's check when that's done

@AdrienCorenflos
Copy link
Contributor

Actually, I think the easiest is if you implement Barker for the standard fixed mass matrix using this new scale thing and then I'll adapt it to the manifold case

@ismael-mendoza
Copy link
Contributor Author

Hi @AdrienCorenflos I noticed something strange in the _format_covariance function that tripped me up when adding the scaling:

if is_inv:
    inv_cov_sqrt = jscipy.linalg.cholesky(cov, lower=True)
    cov_sqrt = jscipy.linalg.solve_triangular(
        inv_cov_sqrt, identity, lower=True, trans=True
    )
else:
    cov_sqrt = jscipy.linalg.cholesky(cov, lower=False).T
    inv_cov_sqrt = jscipy.linalg.solve_triangular(
        cov_sqrt, identity, lower=True, trans=True
    )

am I misunderstanding, or will the inv_cov_sqrt be transposed in the case that is_inv=False and not transposed in the case that is_inv=True? Is that the desired implementation? I was just surprised that it would be inconsistent like that, but maybe it has to do with other algorithms?

@AdrienCorenflos
Copy link
Contributor

Yeah we have been discussing this with @junpenglao
It's because the standard Euclidean metric passes in the inverse mass matrix, whereas the Riemannian uses the normal mass matrix... It's a bit inconsistent and we were going to look into making this coherent after this specific PR. For the time being, I'd recommend you make it work for the Euclidean in the new metric API and I'll take care of the Riemannian.
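The asymmetry is real, but both branches still produce valid square roots of the covariance. A small numpy illustration (not blackjax code) of why the two paths give factors of opposite triangularity:

```python
import numpy as np

cov = np.array([[2.0, 0.6], [0.6, 1.0]])

# is_inv=False path: lower-triangular factor of cov directly
cov_sqrt_lower = np.linalg.cholesky(cov)           # cov = L @ L.T

# is_inv=True path: start from the inverse, as the Euclidean metric does
L_inv = np.linalg.cholesky(np.linalg.inv(cov))     # inv_cov = Li @ Li.T
cov_sqrt_upper = np.linalg.inv(L_inv.T)            # Li^{-T}, upper-triangular

# Both are valid square roots of cov ...
assert np.allclose(cov_sqrt_lower @ cov_sqrt_lower.T, cov)
assert np.allclose(cov_sqrt_upper @ cov_sqrt_upper.T, cov)
# ... but one is lower-triangular and the other upper-triangular,
# which is the inconsistency noted above.
assert not np.allclose(cov_sqrt_lower, cov_sqrt_upper)
```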

option to transpose the mass_matrix_sqrt or inv_mass_matrix_sqrt was
necessary for the barker algorithm as far as I can tell. This has not
been propagated to the riemannian metric
@AdrienCorenflos
Copy link
Contributor

It all seems to work for the standard pre-conditioning, but I'm seeing a lack of invariance for the Riemannian metric:

[image: test output showing the lack of invariance]

I'll fix it. FYI, it comes from the fact that the proposal is no longer symmetric in the Gaussian, so you have to compute the pdf.

Make acceptance function metric agnostic
Add invariance test
Copy link
Contributor

@AdrienCorenflos AdrienCorenflos left a comment


I've made changes so that the proposal works with Riemannian metrics.
There is one last thing to fix and then we're good and I'll leave it to you @ismael-mendoza

-> the barker_nd sampling function was meant to take fully flat vectors, and here you pass it the metric which acts on trees. You are not seeing issues because your tests do not use nested trees, but this will fail. Can you modify the logic so that it simply does everything using metric on the non-flat vectors? This may require changing a few tests on the function here and there but should be easy


Once this is done, it ships
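The request above is essentially to map the metric's operations over pytree leaves instead of flattening. A library-agnostic toy sketch of leaf-wise mapping (plain Python; `jax.tree_util.tree_map` plays this role in the actual codebase):

```python
def tree_map(fn, tree):
    """Apply fn to every leaf of a nested dict/list/tuple structure."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, v) for v in tree)
    return fn(tree)

# a position represented as a nested tree rather than a flat vector
position = {"mu": 0.5, "sigma": [1.0, 2.0]}
shifted = tree_map(lambda leaf: leaf + 1.0, position)
assert shifted == {"mu": 1.5, "sigma": [2.0, 3.0]}
```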

@ismael-mendoza
Copy link
Contributor Author

Thanks @AdrienCorenflos

Can you modify the logic so that it simply does everything using metric on the non-flat vectors?

just to make sure I understand, do you want me to remove metric from _barker_sample_nd entirely and only pass in flat vectors to this function? and the metric logic should live in _barker_sample ?

@AdrienCorenflos
Copy link
Contributor

AdrienCorenflos commented Oct 2, 2024 via email

@ismael-mendoza
Copy link
Contributor Author

oh I think I understand your suggestion now, will modify and commit soon

@ismael-mendoza
Copy link
Contributor Author

@AdrienCorenflos I think I implemented what you had in mind

@AdrienCorenflos
Copy link
Contributor

Thanks, I'll check tomorrow. @junpenglao given I've touched that PR, if you could also have a quick glance that would be useful.

Copy link
Member

@junpenglao junpenglao left a comment


Some minor nits: could you try the suggested changes and see if they work?

blackjax/mcmc/metrics.py (outdated review thread, resolved)
blackjax/mcmc/metrics.py (outdated review thread, resolved)
ismael-mendoza and others added 4 commits October 4, 2024 11:21
make inv and trans required kwarg with type bool in metric.scale

Co-authored-by: Junpeng Lao <[email protected]>
lax.cond might not be needed in metric.scale as inv and trans are static kwarg

Co-authored-by: Junpeng Lao <[email protected]>
@ismael-mendoza
Copy link
Contributor Author

@junpenglao thanks for the suggestion. I think it works, but I see that one of the HMC benchmarks had a regression, although it's unclear to me how that could have happened from just implementing your suggestion.

@junpenglao
Copy link
Member

@junpenglao thanks for the suggestion. I think it works, but I see that one of the HMC benchmarks had a regression, although it's unclear to me how that could have happened from just implementing your suggestion.

Yeah, the regression is not related; I will investigate.

@junpenglao
Copy link
Member

@AdrienCorenflos any more comments? Otherwise LGTM.

@AdrienCorenflos
Copy link
Contributor

Nope, I had no additional comments when I asked for your opinion :)

@junpenglao junpenglao merged commit b107f9f into blackjax-devs:main Oct 8, 2024
4 of 5 checks passed
@junpenglao
Copy link
Member

Thank you @ismael-mendoza !!

@ismael-mendoza ismael-mendoza deleted the barker-inverse-mm2 branch October 8, 2024 18:03
aphc14 pushed a commit to aphc14/blackjax that referenced this pull request Oct 19, 2024
* Draft pre-conditioning matrix in Barker proposal.

This is a first draft of adding the pre-conditioning to the Barker
proposal. This follows Algorithms 4 and 5 in Appendix G of the original
Barker proposal paper. It's somewhat unclear from the paper, but the
separate step size that was already implemented serves as a global
scale for the normal distribution of the proposal. The function
`_compute_acceptance_probability` now takes in the transposed sqrt mass
matrix and its inverse; it has also been flattened to accommodate
the corresponding matrix multiplications.

* Fix typing of inverse_mass_matrix argument
Fix typing of mass matrix.

* Fix docstrings.

The original docstring of step_size was incorrect: there is no
symplectic integrator.

* Make test for Barker in test_sampling run again

We make this possible by adding an identity pre-conditioning matrix,
which should make the test run in the same way as before.

* Add test to ensure correctness of precond matrix

We add a new test to barker.py to ensure that our implementation of
the preconditioning matrix is correct. We follow Appendix G in the
paper that mentions that algorithm 4 and 5 (which we implemented)
should be equivalent to rescaling the parameters and the logdensity
in a specific way. We implement both approaches when using the Barker
proposal to infer the mean and sigma of a normal distribution. We
check that with two different random seeds the output chains are
equivalent up to some tolerance.

We also patch the original test in this file by adding an identity
mass matrix.

* Fix dimensionality of identity matrix

* Add missing mass matrix in missing tests.

* added option to transpose the matrix when scaling

option to transpose the mass_matrix_sqrt or inv_mass_matrix_sqrt was
necessary for the barker algorithm as far as I can tell. This has not
been propagated to the riemannian metric

* use the metric scaling function in barker

Here we use the new metric.scale function to perform the operations
required by the Barker proposal algorithm, instead of passing around
the mass_matrix_sqrt and inv_mass_matrix_sqrt directly. We also
make the `inverse_mass_matrix` argument optional to avoid breaking
the API.

* update test_sampling with barker api

the mass matrix is now an optional argument in barker.

* update test_barker so it works with metric.scale

* fix tests add trans to scale

* add trans argument to riemannian scaling

* no default

* Update barker.py

Make acceptance function metric agnostic

* Update test_barker.py

Add invariance test

* simplify logic to remove _barker_sample_nd

* fix bug so now everything is tree_mapped in barker

* fix test to not use _barker_sample_nd

* Update blackjax/mcmc/metrics.py

make inv and trans required kwarg with type bool in metric.scale

Co-authored-by: Junpeng Lao <[email protected]>

* Update blackjax/mcmc/metrics.py

lax.cond might not be needed in metric.scale as inv and trans are static kwarg

Co-authored-by: Junpeng Lao <[email protected]>

* propagate changes of inv, trans as required kwarg

* fix test metrics

---------

Co-authored-by: Adrien Corenflos <[email protected]>
Co-authored-by: Junpeng Lao <[email protected]>