Skip to content

Commit

Permalink
Merge pull request #96 from ecmerkle/priors
Browse files Browse the repository at this point in the history
add shrink_t option
  • Loading branch information
ecmerkle authored Jan 15, 2025
2 parents f4766a7 + 283e804 commit 6a2bf90
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 24 deletions.
71 changes: 55 additions & 16 deletions R/stanmarg_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,24 +73,18 @@ format_priors <- function(lavpartable, level = 1L) {
lavpartable <- lavpartable[order(lavpartable$col, lavpartable$row),]
}

transtab <- list(c('lambda_y_mn', 'lambda_y_sd', 'len_lam_y'),
c('lambda_x_mn', 'lambda_x_sd', 'len_lam_x'),
c('gamma_mn', 'gamma_sd', 'len_gam'),
c('b_mn', 'b_sd', 'len_b'),
transtab <- list(c('lambda_y_mn', 'lambda_y_sd', 'len_lam_y', 'lambda_y_pri', 'lambda_y_blk'),
c('b_mn', 'b_sd', 'len_b', 'b_pri', 'b_blk'),
c('theta_sd_shape', 'theta_sd_rate', 'len_thet_sd', 'theta_pow'),
c('theta_x_sd_shape', 'theta_x_sd_rate', 'len_thet_x_sd', 'theta_x_pow'),
c('theta_r_alpha', 'theta_r_beta', 'len_thet_r'),
c('theta_x_r_alpha', 'theta_x_r_beta', 'len_thet_x_r'),
c('psi_sd_shape', 'psi_sd_rate', 'len_psi_sd', 'psi_pow'),
c('psi_r_alpha', 'psi_r_beta', 'len_psi_r'),
c('phi_sd_shape', 'phi_sd_rate', 'len_phi_sd', 'phi_pow'),
c('phi_r_alpha', 'phi_r_beta', 'len_phi_r'),
c('nu_mn', 'nu_sd', 'len_nu'),
c('alpha_mn', 'alpha_sd', 'len_alph'),
c('nu_mn', 'nu_sd', 'len_nu', 'nu_pri', 'nu_blk'),
c('alpha_mn', 'alpha_sd', 'len_alph', 'alpha_pri', 'alpha_blk'),
c('tau_mn', 'tau_sd', 'len_tau'))

mats <- c('lambda', 'lambda_x', 'gamma', 'beta', 'thetavar', 'cov.xvar', 'thetaoff',
'cov.xoff', 'psivar', 'psioff', 'phivar', 'phioff', 'nu', 'alpha', 'tau')
mats <- c('lambda', 'beta', 'thetavar', 'thetaoff',
'psivar', 'psioff', 'nu', 'alpha', 'tau')
if (level == 2L) {
newmats <- c('lambda', 'beta', 'thetavar', 'thetaoff', 'psivar', 'psioff', 'nu', 'alpha')
subloc <- match(newmats, mats)
Expand All @@ -100,6 +94,19 @@ format_priors <- function(lavpartable, level = 1L) {
}

out <- list()

## shrinkage priors without <.>
shrpris <- which(grepl("shrink_t", lavpartable$prior) & !grepl("<?>", lavpartable$prior))
if (length(shrpris) > 0) {
lavpartable$prior[shrpris] <- paste0(lavpartable$prior[shrpris], "<999>")
}
## if we have prior blocks specified via <.>, number them for the whole partable
blkpris <- grep("<?>", lavpartable$prior)
blknum <- rep(0, length(lavpartable$prior))
if (length(blkpris) > 0) {
blknum[blkpris] <- as.numeric( as.factor(lavpartable$prior[blkpris]) )
}
lavpartable$blknum <- blknum

for (i in 1:length(mats)) {
mat <- origmat <- mats[i]
Expand All @@ -125,21 +132,29 @@ format_priors <- function(lavpartable, level = 1L) {

prisel <- prisel & (lavpartable$free > 0)
thepris <- lavpartable$prior[prisel]
priblks <- lavpartable$blknum[prisel]
blkmats <- mat %in% c("nu", "lambda", "beta", "alpha")

if (length(thepris) > 0) {
textpris <- thepris[thepris != ""]

prisplit <- strsplit(textpris, "[, ()]+")

param1 <- sapply(prisplit, function(x) x[2])
prinms <- sapply(prisplit, function(x) x[1])

if (!grepl("\\[", prisplit[[1]][3])) {
if (!grepl("\\[", prisplit[[1]][3]) & !blkmats) {
param2 <- sapply(prisplit, function(x) x[3])
if (any(is.na(param2)) & mat == "lvrho") {
## omit lkj here
param1 <- param1[!is.na(param2)]
param2 <- param2[!is.na(param2)]
}
} else if (blkmats) {
pritype <- array(0, length(param1))
pritype[prinms == "shrink_t"] <- 1
param2 <- sapply(prisplit, function(x) x[3])
param2 <- as.numeric(param2)
} else {
param2 <- rep(NA, length(param1))
}
Expand All @@ -163,6 +178,8 @@ format_priors <- function(lavpartable, level = 1L) {
param1 <- array(0, 0)
param2 <- array(0, 0)
powpar <- 1
pritype <- array(0, 0)
priblks <- array(0, 0)
}

out[[ transtab[[i]][1] ]] <- param1
Expand All @@ -172,6 +189,10 @@ format_priors <- function(lavpartable, level = 1L) {
if (origmat %in% c('thetavar', 'cov.xvar', 'psivar', 'phivar')) {
out[[ transtab[[i]][4] ]] <- powpar
}
if (blkmats) {
out[[ transtab[[i]][4] ]] <- pritype
out[[ transtab[[i]][5] ]] <- priblks
}
} # mats

return(out)
Expand All @@ -184,12 +205,18 @@ format_priors <- function(lavpartable, level = 1L) {
# @return nothing
check_priors <- function(lavpartable) {
right_pris <- sapply(dpriors(target = "stan"), function(x) strsplit(x, "[, ()]+")[[1]][1])
## add additional prior options here
new_pri <- rep("shrink_t", 4); names(new_pri) <- c("nu", "alpha", "lambda", "beta")
right_pris <- c(right_pris, new_pri)
pt_pris <- sapply(lavpartable$prior[lavpartable$prior != ""], function(x) strsplit(x, "[, ()]+")[[1]][1])
names(pt_pris) <- lavpartable$mat[lavpartable$prior != ""]
right_pris <- c(right_pris, lvrho = "lkj_corr")
primatch <- match(names(pt_pris), names(right_pris))
badpris <- which(pt_pris != right_pris[primatch])

badpris <- rep(FALSE, length(pt_pris))
for (i in 1:length(pt_pris)) {
badpris[i] <- !(pt_pris[i] %in% right_pris[names(right_pris) == names(pt_pris)[i]])
}
badpris <- which(badpris)
## lvrho entries could also receive beta priors
okpris <- which(names(pt_pris[badpris]) == "lvrho" & pt_pris[badpris] == "beta")
if (length(okpris) > 0) badpris <- badpris[-okpris]
Expand Down Expand Up @@ -441,7 +468,19 @@ stanmarg_data <- function(YX = NULL, S = NULL, YXo = NULL, N, Ng, grpnum, # data
dat <- c(dat, format_priors(lavpartable[lavpartable$level == levlabs[1],]))
dat <- c(dat, format_priors(lavpartable[lavpartable$level == levlabs[2],], level = 2L))
}

allblks <- with(dat, c(lambda_y_blk, b_blk, nu_blk, alpha_blk))
priblks <- table(c(0, allblks))[-1]
dat$npriblks <- length(priblks)
dat$priblklen <- 0
dat$blkparm1 <- array(0, 0)
dat$blkparm2 <- array(0, 0)
if (dat$npriblks > 0) {
dat$priblklen <- max(priblks)
allparm1 <- with(dat, c(lambda_y_mn, b_mn, nu_mn, alpha_mn))
dat$blkparm1 <- array(tapply(allparm1[allblks > 0], allblks[allblks > 0], head, 1), dat$npriblks)
allparm2 <- with(dat, c(lambda_y_sd, b_sd, nu_sd, alpha_sd))
dat$blkparm2 <- array(tapply(allparm2[allblks > 0], allblks[allblks > 0], head, 1), dat$npriblks)
}
return(dat)
}

Expand Down
86 changes: 78 additions & 8 deletions inst/stan/stanmarg.stan
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,68 @@ functions { // you can use these in R following `rstan::expose_stan_functions("f
}

return out;
}
}

real eval_priors(vector Lambda_y_free, vector B_free, vector Nu_free, vector Alpha_free, vector sd0, vector lambda_y_primn, array[] real lambda_y_sd, array[] int lambda_y_pri, array[] int lambda_y_blk, vector b_primn, array[] real b_sd, array[] int b_pri, array[] int b_blk, vector nu_primn, array[] real nu_sd, array[] int nu_pri, array[] int nu_blk, vector alpha_primn, array[] real alpha_sd, array[] int alpha_pri, array[] int alpha_blk, array[] int len_free, vector blkparm1, vector blkparm2, int priblklen, int npriblks) {
real out = 0.0;
if (npriblks == 0) {
out += normal_lpdf(Lambda_y_free | lambda_y_primn, lambda_y_sd);
out += normal_lpdf(B_free | b_primn, b_sd);
out += normal_lpdf(Nu_free | nu_primn, nu_sd);
out += normal_lpdf(Alpha_free | alpha_primn, alpha_sd);
} else {
for (i in 1:npriblks) {
vector[priblklen] parvec;
int npars = 1;
if (len_free[1] > 0) {
for (j in 1:len_free[1]) {
if (lambda_y_pri[j] == 0) {
out += normal_lpdf(Lambda_y_free[j] | lambda_y_primn[j], lambda_y_sd[j]);
} else if (lambda_y_blk[j] == i && lambda_y_pri[j] == 1) {
parvec[npars] = Lambda_y_free[j];
npars += 1;
}
}
}

if (len_free[4] > 0) {
for (j in 1:len_free[4]) {
if (b_pri[j] == 0) {
out += normal_lpdf(B_free[j] | b_primn[j], b_sd[j]);
} else if (b_blk[j] == i && b_pri[j] == 1) {
parvec[npars] = B_free[j];
npars += 1;
}
}
}

if (len_free[13] > 0) {
for (j in 1:len_free[13]) {
if (nu_pri[j] == 0) {
out += normal_lpdf(Nu_free[j] | nu_primn[j], nu_sd[j]);
} else if (nu_blk[j] == i && nu_pri[j] == 1) {
parvec[npars] = Nu_free[j];
npars += 1;
}
}
}

if (len_free[14] > 0) {
for (j in 1:len_free[14]) {
if (alpha_pri[j] == 0) {
out += normal_lpdf(Alpha_free[j] | alpha_primn[j], alpha_sd[j]);
} else if (alpha_blk[j] == i && alpha_pri[j] == 1) {
parvec[npars] = Alpha_free[j];
npars += 1;
}
}
}
out += normal_lpdf(parvec[1:(npars - 1)] | 0, sd0[i]);
out += student_t_lpdf(sd0[i] | blkparm1[i], 0, blkparm2[i]) - log(.5); // left-truncated at 0
}
}
return out;
}
}
data {
// see p. 2 https://books.google.com/books?id=9AC-s50RjacC
Expand Down Expand Up @@ -602,6 +663,10 @@ data {
int<lower=0, upper=1> do_test; // should we do everything in generated quantities?
array[Np] vector[multilev ? p_tilde : p + q - Nord] YXbar; // sample means of continuous manifest variables
array[Np] matrix[multilev ? (p_tilde + 1) : (p + q - Nord + 1), multilev ? (p_tilde + 1) : (p + q - Nord + 1)] S; // sample covariance matrix among all continuous manifest variables NB!! multiply by (N-1) to use wishart lpdf!!
int<lower=0> npriblks; // how many blocks of parameters for prior distributions?
int<lower=0> priblklen; // max number of parameters in a block
vector[npriblks] blkparm1; // parameters of block priors
vector[npriblks] blkparm2;

array[sum(nclus[,2])] int<lower=1> cluster_size; // number of obs per cluster
array[Ng] int<lower=1> ncluster_sizes; // number of unique cluster sizes
Expand Down Expand Up @@ -631,7 +696,6 @@ data {
vector[multilev ? sum(ncluster_sizes) : Ng] log_lik_x; // ll of fixed x variables by unique cluster size
vector[multilev ? sum(nclus[,2]) : Ng] log_lik_x_full; // ll of fixed x variables by cluster


/* sparse matrix representations of skeletons of coefficient matrices,
which is not that interesting but necessary because you cannot pass
missing values into the data block of a Stan program from R */
Expand All @@ -645,6 +709,8 @@ data {
int<lower=0> len_lam_y; // number of free elements minus equality constraints
array[len_lam_y] real lambda_y_mn; // prior
array[len_lam_y] real<lower=0> lambda_y_sd;
array[len_lam_y] int<lower=0> lambda_y_pri;
array[len_lam_y] int<lower=0> lambda_y_blk;

// same things but for B
int<lower=0> len_w4;
Expand All @@ -657,6 +723,8 @@ data {
int<lower=0> len_b;
array[len_b] real b_mn;
array[len_b] real<lower=0> b_sd;
array[len_b] int<lower=0> b_pri;
array[len_b] int<lower=0> b_blk;

// same things but for diag(Theta)
int<lower=0> len_w5;
Expand Down Expand Up @@ -726,6 +794,8 @@ data {
int<lower=0> len_nu;
array[len_nu] real nu_mn;
array[len_nu] real<lower=0> nu_sd;
array[len_nu] int<lower=0> nu_pri;
array[len_nu] int<lower=0> nu_blk;

// same things but for Alpha
int<lower=0> len_w14;
Expand All @@ -738,7 +808,9 @@ data {
int<lower=0> len_alph;
array[len_alph] real alpha_mn;
array[len_alph] real<lower=0> alpha_sd;

array[len_alph] int<lower=0> alpha_pri;
array[len_alph] int<lower=0> alpha_blk;

// same things but for Tau
int<lower=0> len_w15;
array[Ng] int<lower=0> wg15;
Expand Down Expand Up @@ -1172,6 +1244,7 @@ parameters {
vector[len_free[13]] Nu_free;
vector[len_free[14]] Alpha_free;
vector[len_free[15]] Tau_ufree;
vector<lower=0>[npriblks] sd0; // shrink_t parameters

vector<lower=0,upper=1>[Noent] z_aug; //augmented ordinal data
vector[len_free_c[1]] Lambda_y_free_c;
Expand Down Expand Up @@ -1552,11 +1625,8 @@ model { // N.B.: things declared in the model block do not get saved in the outp
}
}

/* prior densities in log-units */
target += normal_lpdf(Lambda_y_free | lambda_y_primn, lambda_y_sd);
target += normal_lpdf(B_free | b_primn, b_sd);
target += normal_lpdf(Nu_free | nu_primn, nu_sd);
target += normal_lpdf(Alpha_free | alpha_primn, alpha_sd);
/* prior densities in log-units, first for unbounded parameters that could have shrinkage priors */
target += eval_priors(Lambda_y_free, B_free, Nu_free, Alpha_free, sd0, lambda_y_primn, lambda_y_sd, lambda_y_pri, lambda_y_blk, b_primn, b_sd, b_pri, b_blk, nu_primn, nu_sd, nu_pri, nu_blk, alpha_primn, alpha_sd, alpha_pri, alpha_blk, len_free, blkparm1, blkparm2, priblklen, npriblks);
target += normal_lpdf(Tau_ufree | tau_primn, tau_sd);

target += normal_lpdf(Lambda_y_free_c | lambda_y_primn_c, lambda_y_sd_c);
Expand Down

0 comments on commit 6a2bf90

Please sign in to comment.