Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extension to handle more than two treatment arms #20

Merged
merged 5 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::rcmdcheck, margins=?ignore
extra-packages: any::rcmdcheck
needs: check

- uses: r-lib/actions/check-r-package@v2
Expand Down
7 changes: 1 addition & 6 deletions .github/workflows/pkgdown.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,9 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::pkgdown, local::., margins=?ignore
extra-packages: any::pkgdown, local::.
needs: website

- name: Install archived dependencies (margins)
run: |
R -e 'install.packages("https://cran.r-project.org/src/contrib/Archive/prediction/prediction_0.3.17.tar.gz")'
R -e 'install.packages("https://cran.r-project.org/src/contrib/Archive/margins/margins_0.3.26.tar.gz")'

- name: Build site
run: pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE)
shell: Rscript {0}
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/rhub.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ jobs:
with:
token: ${{ secrets.RHUB_TOKEN }}
job-config: ${{ matrix.config.job-config }}
extra-packages: any::rcmdcheck, margins=?ignore
extra-packages: any::rcmdcheck
- uses: r-hub/actions/run-check@v1
with:
token: ${{ secrets.RHUB_TOKEN }}
Expand Down Expand Up @@ -90,7 +90,7 @@ jobs:
with:
job-config: ${{ matrix.config.job-config }}
token: ${{ secrets.RHUB_TOKEN }}
extra-packages: any::rcmdcheck, margins=?ignore
extra-packages: any::rcmdcheck
- uses: r-hub/actions/run-check@v1
with:
job-config: ${{ matrix.config.job-config }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::covr, margins=?ignore
extra-packages: any::covr
needs: coverage

- name: Test coverage
Expand Down
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: beeca
Title: Binary Endpoint Estimation with Covariate Adjustment
Version: 0.1.3.9000
Version: 0.2.0
Authors@R:
c(person(given = "Alex",
family = "Przybylski",
Expand Down Expand Up @@ -33,15 +33,15 @@ Maintainer: Alex Przybylski <[email protected]>
License: LGPL (>= 3)
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.2
Suggests:
knitr,
rmarkdown,
testthat (>= 3.0.0),
tidyr,
marginaleffects,
margins,
RobinCar (== 0.3.0)
RobinCar (>= 0.3.0)
Config/testthat/edition: 3
Depends:
R (>= 2.10)
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ importFrom(stats,predict)
importFrom(stats,terms)
importFrom(stats,var)
importFrom(stats,vcov)
importFrom(utils,combn)
importFrom(utils,packageVersion)
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# beeca (development version)
# beeca 0.2.0

- Extensions to allow for more than two treatment arms in the model fit.

# beeca 0.1.3

Expand Down
100 changes: 65 additions & 35 deletions R/apply_contrast.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@
#' The choice of contrast affects how treatment effects are calculated and
#' interpreted. Default is `diff`.
#'
#' @param reference a string indicating which treatment group should be considered as
#' the reference level. Accepted values are one of the levels in the treatment
#' variable. Default to the first level used in the `glm` object.
#' @param reference a string or list of strings indicating which treatment
#' group(s) to use as reference level for pairwise comparisons. Accepted values
#' must be a subset of the levels in the treatment variable. Default to the
#' first n-1 treatment levels used in the `glm` object.
#'
#' This parameter influences the calculation of treatment effects
#' relative to the chosen reference group.
Expand Down Expand Up @@ -68,18 +69,19 @@
#' fit3$marginal_est
#' fit3$marginal_se
#'
#' @importFrom utils combn
#' @export
#'
apply_contrast <- function(object, contrast = c("diff", "rr", "or", "logrr", "logor"), reference) {
# assert means are available in object
# Assert means are available in object
if (!"counterfactual.means" %in% names(object)) {
msg <- sprintf(
"Missing counterfactual means. First run `%1$s <- average_predictions(%1$s)`.",
deparse(quote(object))
)
stop(msg, call. = FALSE)
}
# assert varcov is available in object
# Assert varcov is available in object
if (!"robust_varcov" %in% names(object)) {
msg <- sprintf(
"Missing robust varcov. First run `%1$s <- get_varcov(%1$s, ...)`.",
Expand All @@ -92,53 +94,81 @@ apply_contrast <- function(object, contrast = c("diff", "rr", "or", "logrr", "lo
object <- .assert_sanitized(object, trt, warn = T)

data <- .get_data(object)
# assert non-missing reference

# Assert non-missing reference or set default
if (missing(reference)) {
reference <- object$xlevels[[trt]][1]
warning(sprintf("No reference argument was provided, using %s as the reference level", reference), call. = FALSE)
reference <- object$xlevels[[trt]][-nlevels(data[[trt]])]
warning(sprintf("No reference argument was provided, using {%s} as the reference level(s)", paste(reference, collapse = ", ")), call. = FALSE)
}

# Assert reference to be subset of the treatment levels
if (!all(reference %in% levels(data[[trt]]))) {
stop("Reference levels must be a subset of treatment levels : ", paste(levels(data[[trt]]), collapse = ", "), ".")
}

# assert reference to be one of the treatment levels
if (!reference %in% levels(data[[trt]])) {
stop("Reference must be one of : ", paste(levels(data[[trt]]), collapse = ", "), ".")
# Only accept at most nlevels(trt)-1 reference levels
if (length(reference) > (nlevels(data[[trt]])-1)) {
stop(sprintf("Too many reference levels provided. Expecting at most %s value(s) from the possible treatment levels: {%s}",
nlevels(data[[trt]]) - 1,
paste(levels(data[[trt]]), collapse = ", ")))
}

# match contrast argument and retrieve relevant functions
# Match contrast argument and retrieve relevant helper functions
contrast <- match.arg(contrast)
c_est <- get(contrast)
c_str <- get(paste0(contrast, "_str"))
c_grad <- get(paste0("grad_", contrast))

# set reference to 1 if it is the first level
reference_n <- which(levels(data[[trt]]) == reference)
cf_mean_ref <- object$counterfactual.means[reference_n]
cf_mean_inv <- object$counterfactual.means[3 - reference_n]

trt_ref <- reference
trt_inv <- levels(data[[trt]])[-reference_n]
# Extract counterfactual means
cf_means <- object$counterfactual.means

# apply contrast to point estimate
marginal_est <- c_est(cf_mean_inv, cf_mean_ref)
# Get all pairwise arm comparisons for contrasts
# Format according to specified reference levels
l <- levels(data[[trt]])
combos <- combn(l, 2, \(x) {
if (any(reference %in% x)) {
ref <- intersect(reference, x)[[1]]
comp <- x[x!=ref]
c(which(ref == l), which(comp==l))
} else {
c(NA, NA)
}
})
combos <- combos[, is.finite(colSums(combos)), drop=F]
if (length(reference) > 1) {
combos <- combos[,order(combos[1,])]
}
idxs <- matrix(F, nrow=length(l), ncol=length(l))
rownames(idxs) <- colnames(idxs) <- l
idxs[cbind(combos[2,], combos[1,])] <- T

object$marginal_est <- marginal_est
names(object$marginal_est) <- "marginal_est"
attr(object$marginal_est, "reference") <- trt_ref
attr(object$marginal_est, "contrast") <- paste0(contrast, ": ", c_str(trt_inv, trt_ref))

# apply contrast to varcov
gr <- c_grad(cf_mean_inv, cf_mean_ref)
# Get marginal estimates for the contrasts of interest, using correct ref levels
marginal_est <- outer(cf_means, cf_means, c_est)
marginal_est <- marginal_est[idxs]

robust_varcov <- object$robust_varcov
# Define contrast jacobian
gr <- matrix(0, nrow=ncol(combos), ncol=length(l))
# Apply the grad function to all combinations
contrast_values <- t(apply(combos, 2, function(idx) c_grad(cf_means[idx[2]], cf_means[idx[1]])))
gr[cbind(1:ncol(combos), combos[1,])] <- contrast_values[, 1]
gr[cbind(1:ncol(combos), combos[2,])] <- contrast_values[, 2]
# Get marginal standard error for contrasts of interest
marginal_se <- sqrt(diag(gr %*% object$robust_varcov %*% t(gr)))

# correct varcov order based on reference
if (reference_n == 2) diag(robust_varcov) <- rev(diag(robust_varcov))
# Add string description for each comparison
contrast_desc <- paste0(contrast, ": ",
apply(combos, 2, \(x) c_str(l[x[2]], l[x[1]])))
names(marginal_est) <- contrast_desc
names(marginal_se) <- contrast_desc

marginal_se <- sqrt(t(gr) %*% robust_varcov %*% gr)
object$marginal_est <- marginal_est
attr(object$marginal_est, "reference") <- reference
attr(object$marginal_est, "contrast") <- contrast_desc

object$marginal_se <- marginal_se[[1]]
names(object$marginal_se) <- "marginal_se"
attr(object$marginal_se, "reference") <- trt_ref
attr(object$marginal_se, "contrast") <- paste0(contrast, ": ", c_str(trt_inv, trt_ref))
object$marginal_se <- marginal_se
attr(object$marginal_se, "reference") <- reference
attr(object$marginal_se, "contrast") <- contrast_desc
attr(object$marginal_se, "type") <- attr(object$robust_varcov, "type")

return(object)
Expand Down
Loading
Loading