From 3e9bdc39216141d4c00430a8ca76946b52de7f5c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 19 Jul 2024 17:12:15 +0200 Subject: [PATCH] Don't explicitly construct ILR matrix using JAX --- simplex_transforms/jax/transforms/ilr.py | 37 +++++++++++------------- tests/jax/test_transforms.py | 26 +++++++++++++++-- 2 files changed, 41 insertions(+), 22 deletions(-) diff --git a/simplex_transforms/jax/transforms/ilr.py b/simplex_transforms/jax/transforms/ilr.py index 83e03dd..d8578c3 100644 --- a/simplex_transforms/jax/transforms/ilr.py +++ b/simplex_transforms/jax/transforms/ilr.py @@ -2,39 +2,36 @@ import jax.numpy as jnp -def _make_helmert_matrix(N: int): - H1 = jnp.full(N, 1 / jnp.sqrt(N)) - H22 = jnp.tril(jnp.ones((N - 1, N - 1))) +def _get_V_mul_y(y): + N = y.shape[-1] + 1 ns = jnp.arange(1, N) - H22 = H22.at[jnp.diag_indices_from(H22)].set(-ns) / jnp.sqrt(ns * (ns + 1)).reshape( - -1, 1 + w = y / jnp.sqrt(ns * (ns + 1)) + w_rev_sum = jnp.flip(jnp.cumsum(jnp.flip(w, axis=-1), axis=-1), axis=-1) + zeros = jnp.zeros(w_rev_sum.shape[:-1] + (1,)) + z = jnp.concatenate([w_rev_sum, zeros], axis=-1) - jnp.concatenate( + [zeros, ns * w], axis=-1 ) - H21 = H22[:, :1].at[0].multiply(-1) - H = jnp.block([[H1], [H21, H22]]) - return H + return z -def _make_semiorthogonal_matrix(N: int): - H = _make_helmert_matrix(N) - V = H.T[:, 1:] - return V +def _get_V_trans_mul_z(z): + N = z.shape[-1] + ns = jnp.arange(1, N) + y = (jnp.cumsum(z[..., :-1], axis=-1) - ns * z[..., 1:]) / jnp.sqrt(ns * (ns + 1)) + return y class ILR: def unconstrain(self, x): - N = x.shape[-1] - V = _make_semiorthogonal_matrix(N) - return jnp.dot(jnp.log(x), V) + return _get_V_trans_mul_z(jnp.log(x)) def constrain(self, y): - N = y.shape[-1] + 1 - V = _make_semiorthogonal_matrix(N) - return jax.nn.softmax(jnp.dot(y, V.T), axis=-1) + z = _get_V_mul_y(y) + return jax.nn.softmax(z, axis=-1) def constrain_with_logdetjac(self, y): N = y.shape[-1] + 1 - V = _make_semiorthogonal_matrix(N) - z = jnp.dot(y, V.T) + z = _get_V_mul_y(y) logx = jax.nn.log_softmax(z, axis=-1) x = jnp.exp(logx) logJ = jnp.sum(logx, axis=-1) + jnp.log(N) / 2 diff --git a/tests/jax/test_transforms.py b/tests/jax/test_transforms.py index f064ff3..209f529 100644 --- a/tests/jax/test_transforms.py +++ b/tests/jax/test_transforms.py @@ -45,6 +45,24 @@ def _allclose(x, y, **kwargs): return jnp.allclose(x, y, **kwargs) +def _make_helmert_matrix(N: int): + H1 = jnp.full(N, 1 / jnp.sqrt(N)) + H22 = jnp.tril(jnp.ones((N - 1, N - 1))) + ns = jnp.arange(1, N) + H22 = H22.at[jnp.diag_indices_from(H22)].set(-ns) / jnp.sqrt(ns * (ns + 1)).reshape( + -1, 1 + ) + H21 = H22[:, :1].at[0].multiply(-1) + H = jnp.block([[H1], [H21, H22]]) + return H + + +def _make_ilr_semiorthogonal_matrix(N: int): + H = _make_helmert_matrix(N) + V = H.T[:, 1:] + return V + + def logdetjac(f): jac = jax.jacobian(f) @@ -124,10 +142,14 @@ def test_normalized_transforms_consistent(N, seed=42): @pytest.mark.parametrize("N", [3, 5, 10]) -def test_ilr_semiorthogonal_matrix_properties(N): +def test_ilr_semiorthogonal_matrix_properties(N, seed=87): from simplex_transforms.jax.transforms import ilr - V = ilr._make_semiorthogonal_matrix(N) + V = _make_ilr_semiorthogonal_matrix(N) assert V.shape == (N, N - 1) assert jnp.allclose(V.T @ V, jnp.eye(N - 1)) assert jnp.allclose(V.T @ jnp.ones(N), 0) + y = jax.random.normal(key=jax.random.key(seed), shape=(N - 1,)) + z = V @ y + assert jnp.allclose(ilr._get_V_mul_y(y), z) + assert jnp.allclose(ilr._get_V_trans_mul_z(z), y)