Skip to content

Commit

Permalink
Species-specific sigma
Browse files Browse the repository at this point in the history
  • Loading branch information
fseaton committed Feb 12, 2024
1 parent 645d05e commit 6f98897
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 16 deletions.
8 changes: 4 additions & 4 deletions R/jsdm_stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,10 @@ ifelse(site_intercept == "grouped",
matrix[D, N] LV_uncor; // Per-site latent variable"
var_pars <- switch(family,
"gaussian" = "
real<lower=0> sigma; // Gaussian parameters",
real<lower=0> sigma[S]; // Gaussian parameters",
"bernoulli" = "",
"neg_binomial" = "
real<lower=0> kappa; // neg_binomial parameters",
real<lower=0> kappa[S]; // neg_binomial parameters",
"poisson" = ""
)

Expand Down Expand Up @@ -325,9 +325,9 @@ ifelse(site_intercept == "grouped",
for(j in 1:S) {
log_lik[i, j] = ",
switch(family,
"gaussian" = "normal_lpdf(Y[i, j] | linpred[i, j], sigma);",
"gaussian" = "normal_lpdf(Y[i, j] | linpred[i, j], sigma[j]);",
"bernoulli" = "bernoulli_logit_lpmf(Y[i, j] | linpred[i, j]);",
"neg_binomial" = "neg_binomial_2_log_lpmf(Y[i, j] | linpred[i, j], kappa);",
"neg_binomial" = "neg_binomial_2_log_lpmf(Y[i, j] | linpred[i, j], kappa[j]);",
"poisson" = "poisson_log_lpmf(Y[i, j] | linpred[i, j]);",
"binomial" = "binomial_logit_lpmf(Y[i, j] | Ntrials[i], linpred[i, j]);"
),"
Expand Down
18 changes: 10 additions & 8 deletions R/posterior_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -241,14 +241,16 @@ posterior_predict.jsdmStanFit <- function(object, newdata = NULL,
}
}
} else {
x2 <- apply(x2, 1:2, function(x) {
switch(object$family,
"gaussian" = stats::rnorm(1, x, mod_sigma),
"bernoulli" = stats::rbinom(1, 1, x),
"poisson" = stats::rpois(1, x),
"neg_binomial" = rgampois(1, x, mod_kappa)
)
})
for(i in seq_len(nrow(x2))){
for(j in seq_len(ncol(x2))){
x2[i,j] <- switch(object$family,
"gaussian" = stats::rnorm(1, x2[i,j], mod_sigma[j]),
"bernoulli" = stats::rbinom(1, 1, x2[i,j]),
"poisson" = stats::rpois(1, x2[i,j]),
"neg_binomial" = rgampois(1, x2[i,j], mod_kappa[j])
)
}
}
}
x2
})
Expand Down
14 changes: 10 additions & 4 deletions R/sim_data_funs.R
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m
"LV" = D1 * N,
"L" = D1 * (S - D1) + (D1 * (D1 - 1) / 2) + D1,
"sigma_L" = 1,
"sigma" = 1,
"kappa" = 1
"sigma" = S,
"kappa" = S
)
fun_args <- as.list(c(fun_arg1, as.numeric(unlist(y[[1]][[1]])[-1])))

Expand Down Expand Up @@ -311,9 +311,9 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m
Y[i, j] <- switch(response,
"neg_binomial" = rgampois(1,
mu = exp(mu_ij),
scale = kappa
scale = kappa[j]
),
"gaussian" = stats::rnorm(1, mu_ij, sigma),
"gaussian" = stats::rnorm(1, mu_ij, sigma[j]),
"poisson" = stats::rpois(1, exp(mu_ij)),
"bernoulli" = stats::rbinom(1, 1, inv_logit(mu_ij)),
"binomial" = stats::rbinom(1, Ntrials[i], inv_logit(mu_ij))
Expand Down Expand Up @@ -356,6 +356,12 @@ jsdm_sim_data <- function(N, S, D = NULL, K = 0L, family, method = c("gllvm", "m
if (response == "neg_binomial") {
pars$kappa <- kappa
}
if(response == "gaussian"){
pars$sigma <- sigma
}
if(response == "neg_binomial"){
pars$kappa <- kappa
}
if (isTRUE(species_intercept)) {
if (K > 0) {
x <- x[, 2:ncol(x)]
Expand Down

0 comments on commit 6f98897

Please sign in to comment.