Meeting minutes
Present:
- Junpeng Lao
- Adrien Corenflos
- Rémi Louf
Pausing for now.
Users who will build their own matrix:
- People who want to learn from the library or teach with it; they pass their own mass matrix to show how the matrix impacts sampling.
- Researchers working on new algorithms (mostly adaptation), who will build their own matrix by definition. They need a way to provide a matrix at a high level.
- Applied researchers who already have a quadratic approximation of their potential.
- Adrien: Design document
Junpeng started working on it here. We need to add a description of the high-level code architecture.
`new_state, info = nuts.one_step(rng_key, state, step_size, inverse_mass_matrix)`
Q: What can we do with this design that we could not do before? A: Modify the kernel between two HMC steps; adaptation is currently slower than NumPyro's, which is not cool. The current code also relies on a lot of partial applications, which makes it look less elegant. And now we can `vmap` across arrays of step sizes and mass matrices -> ChEES.
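To make the `vmap` point concrete, here is a minimal sketch, not the blackjax implementation. It assumes a toy `one_step` with a diagonal inverse mass matrix, a standard normal target, and a fixed number of leapfrog steps; all names other than `step_size` and `inverse_mass_matrix` are hypothetical.

```python
import jax
import jax.numpy as jnp


def logprob_fn(x):
    # Standard normal target, chosen only for illustration.
    return -0.5 * jnp.sum(x**2)


def one_step(rng_key, position, step_size, inverse_mass_matrix, num_steps=10):
    """Toy HMC step with a diagonal inverse mass matrix (hypothetical helper)."""
    key_momentum, key_accept = jax.random.split(rng_key)
    grad_fn = jax.grad(logprob_fn)
    # Momentum is drawn from N(0, M), where M is the inverse of `inverse_mass_matrix`.
    momentum = jax.random.normal(key_momentum, position.shape) / jnp.sqrt(inverse_mass_matrix)

    def energy(q, p):
        return -logprob_fn(q) + 0.5 * jnp.sum(inverse_mass_matrix * p**2)

    def leapfrog(carry, _):
        q, p = carry
        p = p + 0.5 * step_size * grad_fn(q)
        q = q + step_size * inverse_mass_matrix * p
        p = p + 0.5 * step_size * grad_fn(q)
        return (q, p), None

    (q_new, p_new), _ = jax.lax.scan(leapfrog, (position, momentum), None, length=num_steps)
    # Metropolis correction based on the change in total energy.
    log_accept = energy(position, momentum) - energy(q_new, p_new)
    accept = jnp.log(jax.random.uniform(key_accept)) < log_accept
    return jnp.where(accept, q_new, position), accept


step_sizes = jnp.array([0.05, 0.1, 0.2])
inverse_mass_matrices = jnp.ones((3, 2))
position = jnp.zeros(2)
rng_key = jax.random.PRNGKey(0)

# Map over the parameters only; the key and the position are shared.
batched_step = jax.vmap(one_step, in_axes=(None, None, 0, 0))
new_positions, accepted = batched_step(rng_key, position, step_sizes, inverse_mass_matrices)
```

Because the parameters are ordinary arguments, no partial application has to be redone when they change between steps.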
Q: Concerns? A: With more parameters it will get more (too?) complicated.
Q: Can we somehow hide that from users at a higher level? Through the code structure?
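import functools as ft
def hmc(logprob_fn, *, step_size=None, inverse_mass_matrix=None):  # public API
    # Bind the parameters only when the user provided both of them.
    if step_size is not None and inverse_mass_matrix is not None:
        return ft.partial(one_step, step_size=step_size, inverse_mass_matrix=inverse_mass_matrix)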
Q: Can we use NamedTuple for parameters?
from typing import NamedTuple, Optional
import jax.numpy as jnp

class HMCParameters(NamedTuple):
    # Defaults are needed so that a partially specified tuple can be built.
    num_integration_steps: Optional[int] = None
    step_size: Optional[float] = None
    inverse_mass_matrix: Optional[jnp.ndarray] = None

params = HMCParameters(10)
### IS THAT OK?
kernel = blackjax.hmc(logprob_fn)  ## function unpacks
kernel.update(rng_key, init_state, params)  ## value of making it more complicated?
# !!! We cannot jit it from the outside here, you cannot jit
# !!! For `vmap` we would have to duplicate the elements we're not mapping over
Actually we can add that on top of the `one_step` function that only takes `step_size`, `inverse_mass_matrix`.
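A minimal sketch of that layering, under stated assumptions: `one_step` below is a stand-in (a scaled Gaussian perturbation, not HMC), and `update` is a hypothetical name for the convenience wrapper that unpacks the parameter tuple.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class HMCParameters(NamedTuple):
    step_size: float
    inverse_mass_matrix: jnp.ndarray


def one_step(rng_key, position, step_size, inverse_mass_matrix):
    # Stand-in for the real kernel: a scaled Gaussian perturbation, not HMC.
    noise = jax.random.normal(rng_key, position.shape)
    return position + step_size * jnp.sqrt(inverse_mass_matrix) * noise


def update(rng_key, position, params):
    # Thin convenience layer: unpack the tuple and defer to `one_step`.
    return one_step(rng_key, position, params.step_size, params.inverse_mass_matrix)


rng_key = jax.random.PRNGKey(0)
position = jnp.zeros(2)
params = HMCParameters(step_size=0.1, inverse_mass_matrix=jnp.ones(2))

# A NamedTuple is a pytree, so this plain function can be jitted as a whole.
via_tuple = jax.jit(update)(rng_key, position, params)
direct = one_step(rng_key, position, 0.1, jnp.ones(2))
```

In this sketch the low-level function stays the single source of truth, and the tuple-based wrapper adds no logic of its own.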
from typing import Callable, NamedTuple

import blackjax

class SamplingKernel(NamedTuple):
    init: Callable
    step: Callable

# Now
init, step = blackjax.hmc(logprob_fn)
state = init(position)
new_state, info = step(rng_key, state)

# With Optax design
hmc = blackjax.hmc(logprob_fn)
state = hmc.init(position)
new_state, info = hmc.step(rng_key, state)
Q: What should we name the second function?
Q: Are there other places in the internals where we can do that?
Decision: Ok, proceed.
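The Optax-style design above can be sketched in a self-contained way. This is only an illustration: `random_walk` and `RWState` are hypothetical names, and a random-walk Metropolis kernel stands in for HMC to keep the example short.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class SamplingKernel(NamedTuple):
    init: Callable
    step: Callable


class RWState(NamedTuple):
    position: jnp.ndarray
    logprob: jnp.ndarray


def random_walk(logprob_fn, step_size=0.5):
    """Hypothetical factory: a random-walk Metropolis kernel standing in for HMC."""

    def init(position):
        return RWState(position, logprob_fn(position))

    def step(rng_key, state):
        key_proposal, key_accept = jax.random.split(rng_key)
        noise = jax.random.normal(key_proposal, state.position.shape)
        proposal = state.position + step_size * noise
        proposal_logprob = logprob_fn(proposal)
        # Symmetric proposal, so the acceptance ratio is the density ratio.
        log_ratio = proposal_logprob - state.logprob
        accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        new_state = RWState(
            jnp.where(accept, proposal, state.position),
            jnp.where(accept, proposal_logprob, state.logprob),
        )
        return new_state, accept

    return SamplingKernel(init, step)


def logprob_fn(x):
    return -0.5 * jnp.sum(x**2)


kernel = random_walk(logprob_fn)
state = kernel.init(jnp.zeros(2))
new_state, accepted = jax.jit(kernel.step)(jax.random.PRNGKey(0), state)

# Tuple unpacking still works, so the current calling style is preserved.
init, step = kernel
```

Returning a `NamedTuple` keeps both call styles available: attribute access as in Optax, and plain unpacking as in the current API.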
Not only would this help harmonize the API with the introduction of SGLD, it is also something that people actually use (see #136).
hmc = blackjax.hmc(logprob_fn: Callable, logprob_grad_fn: Optional[Callable])
If `logprob_grad_fn` is `None`, then `jax.grad` is used to compute the gradient.
Decision: Not for now.
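For reference, the fallback discussed above could look like the following sketch. The helper name `make_grad_fn` is hypothetical and only illustrates the `None` check; it is not part of blackjax.

```python
from typing import Callable, Optional

import jax
import jax.numpy as jnp


def make_grad_fn(logprob_fn: Callable, logprob_grad_fn: Optional[Callable] = None) -> Callable:
    """Return the user-supplied gradient, or differentiate `logprob_fn` with `jax.grad`."""
    if logprob_grad_fn is None:
        return jax.grad(logprob_fn)
    return logprob_grad_fn


def logprob_fn(x):
    return -0.5 * jnp.sum(x**2)


def manual_grad(x):
    # Closed-form gradient of the standard normal log-density above.
    return -x


x = jnp.array([1.0, -2.0])
auto = make_grad_fn(logprob_fn)(x)
manual = make_grad_fn(logprob_fn, manual_grad)(x)
```

Both paths return a function with the same signature, so the rest of the kernel would not need to know which one was chosen.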
Adrien: State-space models shouldn't be in blackjax. They can quickly become complicated to handle.
Rémi:
- Variational inference is fine if someone wants to implement it.
- Bayesian updating is missing in most PPLs. It is currently possible to do with SMC, and we should add an example.
Junpeng:
- HMC-focused, then SMC, more than inference focused
Or scripts (docstrings) -> Sphinx (see Python Optimal Transport). However, users should be able to easily convert them to notebooks; very often people start tinkering with a library from example notebooks.
Q: Can we export those as notebooks *easily*? A: Look at what matplotlib does.
Decision: Use scripts if they can be easily converted to notebooks for users to tinker with; otherwise use jupytext.
- API doc (compiler)
- Restate the scope
- Restate the target audience
- Design principle, code map
- Contributing guide
- If there’s a question about how to do stuff it should probably be raised as an issue.
Decision: We should just do it!
Decision: Rémi investigates.
Adrien:
- Design document inverse mass matrix
Junpeng:
- PR on `chex` for testing
Rémi:
- SGLD design and implementation
- Adaptation