Implementation of some common normalizing flow models allowing for a conditioning context using Jax, Flax, and Distrax. The following are currently implemented:
- Masked/Inverse Autoregressive Flows (MAF/IAF; Papamakarios et al., 2017; Kingma et al., 2016)
- Neural Spline Flows (NSF; Durkan et al, 2019)
- See notebooks/example.ipynb for a simple usage example.
- See notebooks/sbi.ipynb for an example application for neural simulation-based inference (conditional posterior estimation).
import jax
import jax.numpy as jnp  # fix: jnp was used below but never imported

from models.maf import MaskedAutoregressiveFlow
from models.nsf import NeuralSplineFlow

n_dim = 2      # Feature dimensionality of the data
n_context = 1  # Dimensionality of the conditioning context

## Define flow model (MAF shown commented out as an alternative)
# model = MaskedAutoregressiveFlow(n_dim=n_dim, n_context=n_context, hidden_dims=[128,128], n_transforms=12, activation="tanh", use_random_permutations=False)
model = NeuralSplineFlow(n_dim=n_dim, n_context=n_context, hidden_dims=[128,128], n_transforms=8, activation="gelu", n_bins=4)

## Initialize model and params
key = jax.random.PRNGKey(42)
# Split the key so the two uniform draws are independent
# (reusing one key yields correlated x_test and context).
key, x_key, c_key = jax.random.split(key, 3)
x_test = jax.random.uniform(key=x_key, shape=(64, n_dim))
context = jax.random.uniform(key=c_key, shape=(64, n_context))
params = model.init(key, x_test, context)

## Log-prob and sampling
n_samples = 64  # fix: n_samples was used below but never defined
log_prob = model.apply(params, x_test, jnp.ones((x_test.shape[0], n_context)))
samples = model.apply(params, n_samples, key, jnp.ones((n_samples, n_context)), method=model.sample)