Skip to content

Commit

Permalink
Don't explicitly construct ILR matrix using JAX
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Jul 19, 2024
1 parent 8a9e14b commit 3e9bdc3
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 22 deletions.
37 changes: 17 additions & 20 deletions simplex_transforms/jax/transforms/ilr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,36 @@
import jax.numpy as jnp


def _make_helmert_matrix(N: int):
H1 = jnp.full(N, 1 / jnp.sqrt(N))
H22 = jnp.tril(jnp.ones((N - 1, N - 1)))
def _get_V_mul_y(y):
N = y.shape[-1] + 1
ns = jnp.arange(1, N)
H22 = H22.at[jnp.diag_indices_from(H22)].set(-ns) / jnp.sqrt(ns * (ns + 1)).reshape(
-1, 1
w = y / jnp.sqrt(ns * (ns + 1))
w_rev_sum = jnp.flip(jnp.cumsum(jnp.flip(w, axis=-1), axis=-1), axis=-1)
zeros = jnp.zeros(w_rev_sum.shape[:-1] + (1,))
z = jnp.concatenate([w_rev_sum, zeros], axis=-1) - jnp.concatenate(
[zeros, ns * w], axis=-1
)
H21 = H22[:, :1].at[0].multiply(-1)
H = jnp.block([[H1], [H21, H22]])
return H
return z


def _make_semiorthogonal_matrix(N: int):
H = _make_helmert_matrix(N)
V = H.T[:, 1:]
return V
def _get_V_trans_mul_z(z):
N = z.shape[-1]
ns = jnp.arange(1, N)
y = (jnp.cumsum(z[..., :-1], axis=-1) - ns * z[..., 1:]) / jnp.sqrt(ns * (ns + 1))
return y


class ILR:
def unconstrain(self, x):
N = x.shape[-1]
V = _make_semiorthogonal_matrix(N)
return jnp.dot(jnp.log(x), V)
return _get_V_trans_mul_z(jnp.log(x))

def constrain(self, y):
N = y.shape[-1] + 1
V = _make_semiorthogonal_matrix(N)
return jax.nn.softmax(jnp.dot(y, V.T), axis=-1)
z = _get_V_mul_y(y)
return jax.nn.softmax(z, axis=-1)

def constrain_with_logdetjac(self, y):
N = y.shape[-1] + 1
V = _make_semiorthogonal_matrix(N)
z = jnp.dot(y, V.T)
z = _get_V_mul_y(y)
logx = jax.nn.log_softmax(z, axis=-1)
x = jnp.exp(logx)
logJ = jnp.sum(logx, axis=-1) + jnp.log(N) / 2
Expand Down
26 changes: 24 additions & 2 deletions tests/jax/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,24 @@ def _allclose(x, y, **kwargs):
return jnp.allclose(x, y, **kwargs)


def _make_helmert_matrix(N: int):
H1 = jnp.full(N, 1 / jnp.sqrt(N))
H22 = jnp.tril(jnp.ones((N - 1, N - 1)))
ns = jnp.arange(1, N)
H22 = H22.at[jnp.diag_indices_from(H22)].set(-ns) / jnp.sqrt(ns * (ns + 1)).reshape(
-1, 1
)
H21 = H22[:, :1].at[0].multiply(-1)
H = jnp.block([[H1], [H21, H22]])
return H


def _make_ilr_semiorthogonal_matrix(N: int):
H = _make_helmert_matrix(N)
V = H.T[:, 1:]
return V


def logdetjac(f):
jac = jax.jacobian(f)

Expand Down Expand Up @@ -124,10 +142,14 @@ def test_normalized_transforms_consistent(N, seed=42):


@pytest.mark.parametrize("N", [3, 5, 10])
def test_ilr_semiorthogonal_matrix_properties(N):
def test_ilr_semiorthogonal_matrix_properties(N, seed=87):
from simplex_transforms.jax.transforms import ilr

V = ilr._make_semiorthogonal_matrix(N)
V = _make_ilr_semiorthogonal_matrix(N)
assert V.shape == (N, N - 1)
assert jnp.allclose(V.T @ V, jnp.eye(N - 1))
assert jnp.allclose(V.T @ jnp.ones(N), 0)
y = jax.random.normal(key=jax.random.key(seed), shape=(N - 1,))
z = V @ y
assert jnp.allclose(ilr._get_V_mul_y(y), z)
assert jnp.allclose(ilr._get_V_trans_mul_z(z), y)

0 comments on commit 3e9bdc3

Please sign in to comment.