Skip to content

Commit

Permalink
Don't explicitly construct the ILR matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Jul 18, 2024
1 parent f46a1d9 commit 8a9e14b
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 24 deletions.
24 changes: 8 additions & 16 deletions simplex_transforms/stan/transforms/ILR/ILR_functions.stan
Original file line number Diff line number Diff line change
@@ -1,28 +1,20 @@
matrix semiorthogonal_matrix(data int N) {
matrix[N, N - 1] V;
real inv_nrm2;
for (n in 1 : (N - 1)) {
inv_nrm2 = inv_sqrt(n * (n + 1));
V[1 : n, n] = rep_vector(inv_nrm2, n);
V[n + 1, n] = -n * inv_nrm2;
V[(n + 2) : N, n] = rep_vector(0, N - n - 1);
}
return V;
}

vector inv_ilr_simplex_constrain_lp(vector y, data matrix V) {
vector inv_ilr_simplex_constrain_lp(vector y) {
int N = rows(y) + 1;
vector[N] z = V * y;
vector[N - 1] ns = linspaced_vector(N - 1, 1, N - 1);
vector[N - 1] w = y ./ sqrt(ns .* (ns + 1));
vector[N] z = append_row(reverse(cumulative_sum(reverse(w))), 0) - append_row(0, ns .* w);
real r = log_sum_exp(z);
vector[N] x = exp(z - r);
target += 0.5 * log(N);
target += sum(z) - N * r;
return x;
}

vector inv_ilr_log_simplex_constrain_lp(vector y, data matrix V) {
vector inv_ilr_log_simplex_constrain_lp(vector y) {
int N = rows(y) + 1;
vector[N] z = V * y;
vector[N - 1] ns = linspaced_vector(N - 1, 1, N - 1);
vector[N - 1] w = y ./ sqrt(ns .* (ns + 1));
vector[N] z = append_row(reverse(cumulative_sum(reverse(w))), 0) - append_row(0, ns .* w);
real r = log_sum_exp(z);
vector[N] log_x = z - r;
target += 0.5 * log(N);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
transformed data {
matrix[N, N - 1] V = semiorthogonal_matrix(N);
}
parameters {
vector[N - 1] y;
}
transformed parameters {
vector<upper=0>[N] log_x = inv_ilr_log_simplex_constrain_lp(y, V);;
vector<upper=0>[N] log_x = inv_ilr_log_simplex_constrain_lp(y);
simplex[N] x = exp(log_x);
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
transformed data {
matrix[N, N - 1] V = semiorthogonal_matrix(N);
}
parameters {
vector[N - 1] y;
}
transformed parameters {
simplex[N] x = inv_ilr_simplex_constrain_lp(y, V);
simplex[N] x = inv_ilr_simplex_constrain_lp(y);
}

0 comments on commit 8a9e14b

Please sign in to comment.