Skip to content

Commit

Permalink
add tests for dirichlet_multinomial distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
tom-peatman committed Jan 24, 2025
1 parent 031980e commit 85bd4fb
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 2 deletions.
25 changes: 25 additions & 0 deletions tests/local/tests.models-4.R
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,31 @@ test_that("multinomial models work correctly", suppressWarnings({
expect_ggplot(plot(ce, ask = FALSE)[[1]])
}))

test_that("dirichlet_multinomial models work correctly", suppressWarnings({
require("extraDistr")
set.seed(1245)
N <- 100
dat <- as.data.frame(extraDistr::rdirmnom(N, 10, c(10, 5, 1)))
names(dat) <- paste0("y", 1:3)
dat$size <- with(dat, y1 + y2 + y3)
dat$x <- rnorm(N)
dat$y <- with(dat, cbind(y1, y2, y3))

fit <- brm(
y | trials(size) ~ x, data = dat,
family = dirichlet_multinomial(),
prior = prior("exponential(0.01)", "phi")
)
print(summary(fit))
pred <- predict(fit)
expect_equal(dim(pred), c(nobs(fit), 4, 3))
expect_equal(dimnames(pred)[[3]], c("y1", "y2", "y3"))
waic <- waic(fit)
expect_range(waic$estimates[3, 1], 550, 650)
ce <- conditional_effects(fit, categorical = TRUE)
expect_ggplot(plot(ce, ask = FALSE)[[1]])
}))

test_that("dirichlet models work correctly", suppressWarnings({
set.seed(1246)
N <- 100
Expand Down
6 changes: 6 additions & 0 deletions tests/testthat/tests.log_lik.R
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,12 @@ test_that("log_lik for categorical and related models runs without erros", {
ll <- sapply(1:nobs, brms:::log_lik_multinomial, prep = prep)
expect_equal(dim(ll), c(ns, nobs))

prep$data$trials <- sample(1:20, nobs)
prep$dpars$phi <- rexp(ns, 10)
prep$family <- dirichlet_multinomial()
ll <- sapply(1:nobs, brms:::log_lik_dirichlet_multinomial, prep = prep)
expect_equal(dim(ll), c(ns, nobs))

prep$data$Y <- prep$data$Y / rowSums(prep$data$Y)
prep$dpars$phi <- rexp(ns, 10)
prep$family <- dirichlet()
Expand Down
6 changes: 5 additions & 1 deletion tests/testthat/tests.posterior_epred.R
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ test_that("posterior_epred for advanced count data distributions runs without er
expect_equal(dim(pred), c(ns, nobs))
})

test_that("posterior_epred for multinomial and dirichlet models runs without errors", {
test_that("posterior_epred for multinomial, dirichlet_multinomial and dirichlet models runs without errors", {
ns <- 15
nobs <- 8
ncat <- 3
Expand All @@ -198,6 +198,10 @@ test_that("posterior_epred for multinomial and dirichlet models runs without err
pred <- brms:::posterior_epred_multinomial(prep = prep)
expect_equal(dim(pred), c(ns, nobs, ncat))

prep$family <- dirichlet_multinomial()
pred <- brms:::posterior_epred_dirichlet_multinomial(prep = prep)
expect_equal(dim(pred), c(ns, nobs, ncat))

prep$family <- dirichlet()
pred <- brms:::posterior_epred_dirichlet(prep = prep)
expect_equal(dim(pred), c(ns, nobs, ncat))
Expand Down
6 changes: 6 additions & 0 deletions tests/testthat/tests.posterior_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,12 @@ test_that("posterior_predict for categorical and related models runs without err
pred <- brms:::posterior_predict_multinomial(i = sample(1:nobs, 1), prep = prep)
expect_equal(dim(pred), c(ns, ncat))

prep$data$trials <- sample(1:20, nobs)
prep$dpars$phi <- rexp(ns, 1)
prep$family <- dirichlet_multinomial()
pred <- brms:::posterior_predict_dirichlet_multinomial(i = sample(1:nobs, 1), prep = prep)
expect_equal(dim(pred), c(ns, ncat))

prep$dpars$phi <- rexp(ns, 1)
prep$family <- dirichlet()
pred <- brms:::posterior_predict_dirichlet(i = sample(1:nobs, 1), prep = prep)
Expand Down
23 changes: 23 additions & 0 deletions tests/testthat/tests.stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,29 @@ test_that("Stan code for multinomial models is correct", {
expect_match2(scode, "lprior += normal_lpdf(Intercept_muy3 | 0, 2);")
})

test_that("Stan code for dirichlet_multinomial models is correct", {
N <- 15
dat <- data.frame(
y1 = rbinom(N, 10, 0.3), y2 = rbinom(N, 10, 0.5),
y3 = rbinom(N, 10, 0.7), x = rnorm(N)
)
dat$size <- with(dat, y1 + y2 + y3)
dat$y <- with(dat, cbind(y1, y2, y3))
prior <- prior(normal(0, 10), "b", dpar = muy2) +
prior(cauchy(0, 1), "Intercept", dpar = muy2) +
prior(normal(0, 2), "Intercept", dpar = muy3) +
prior(exponential(10), "phi")
scode <- stancode(bf(y | trials(size) ~ 1, muy2 ~ x), data = dat,
family = dirichlet_multinomial(), prior = prior)
expect_match2(scode, "array[N, ncat] int Y;")
expect_match2(scode, "target += dirichlet_multinomial_logit2_lpmf(Y[n] | mu[n], phi);")
expect_match2(scode, "muy2 += Intercept_muy2 + Xc_muy2 * b_muy2;")
expect_match2(scode, "lprior += normal_lpdf(b_muy2 | 0, 10);")
expect_match2(scode, "lprior += cauchy_lpdf(Intercept_muy2 | 0, 1);")
expect_match2(scode, "lprior += normal_lpdf(Intercept_muy3 | 0, 2);")
expect_match2(scode, "lprior += exponential_lpdf(phi | 10);")
})

test_that("Stan code for dirichlet models is correct", {
N <- 15
dat <- as.data.frame(rdirichlet(N, c(3, 2, 1)))
Expand Down
11 changes: 10 additions & 1 deletion tests/testthat/tests.standata.R
Original file line number Diff line number Diff line change
Expand Up @@ -976,7 +976,7 @@ test_that("reserved variables 'Intercept' is handled correctly", {
expect_true(all(sdata$X[, "Intercept"] == 1))
})

test_that("data for multinomial and dirichlet models is correct", {
test_that("data for multinomial, dirichlet_multinomial and dirichlet models is correct", {
N <- 15
dat <- as.data.frame(rdirichlet(N, c(3, 2, 1)))
names(dat) <- c("y1", "y2", "y3")
Expand All @@ -993,6 +993,11 @@ test_that("data for multinomial and dirichlet models is correct", {
expect_equal(sdata$ncat, 3)
expect_equal(sdata$Y, unname(dat$t))

sdata <- standata(t | trials(size) ~ x, dat, dirichlet_multinomial())
expect_equal(sdata$trials, as.array(dat$size))
expect_equal(sdata$ncat, 3)
expect_equal(sdata$Y, unname(dat$t))

sdata <- standata(y ~ x, data = dat, family = dirichlet())
expect_equal(sdata$ncat, 3)
expect_equal(sdata$Y, unname(dat$y))
Expand All @@ -1001,6 +1006,10 @@ test_that("data for multinomial and dirichlet models is correct", {
standata(t | trials(10) ~ x, data = dat, family = multinomial()),
"Number of trials does not match the number of events"
)
expect_error(
standata(t | trials(10) ~ x, data = dat, family = dirichlet_multinomial()),
"Number of trials does not match the number of events"
)
expect_error(standata(t ~ x, data = dat, family = dirichlet()),
"Response values in simplex models must sum to 1")
})
Expand Down

0 comments on commit 85bd4fb

Please sign in to comment.