Skip to content

Commit

Permalink
extension to handle more than two treatment arms (#20)
Browse files Browse the repository at this point in the history
* extension to handle more than two treatment arms

* update robincar suggests dependency

* fix note on import

* update github workflows

* increment version
  • Loading branch information
przybal2 authored Nov 12, 2024
1 parent 6c269af commit 425a5ad
Show file tree
Hide file tree
Showing 23 changed files with 643 additions and 247 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::rcmdcheck, margins=?ignore
extra-packages: any::rcmdcheck
needs: check

- uses: r-lib/actions/check-r-package@v2
Expand Down
7 changes: 1 addition & 6 deletions .github/workflows/pkgdown.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,9 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::pkgdown, local::., margins=?ignore
extra-packages: any::pkgdown, local::.
needs: website

- name: Install archived dependencies (margins)
run: |
R -e 'install.packages("https://cran.r-project.org/src/contrib/Archive/prediction/prediction_0.3.17.tar.gz")'
R -e 'install.packages("https://cran.r-project.org/src/contrib/Archive/margins/margins_0.3.26.tar.gz")'
- name: Build site
run: pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE)
shell: Rscript {0}
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/rhub.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ jobs:
with:
token: ${{ secrets.RHUB_TOKEN }}
job-config: ${{ matrix.config.job-config }}
extra-packages: any::rcmdcheck, margins=?ignore
extra-packages: any::rcmdcheck
- uses: r-hub/actions/run-check@v1
with:
token: ${{ secrets.RHUB_TOKEN }}
Expand Down Expand Up @@ -90,7 +90,7 @@ jobs:
with:
job-config: ${{ matrix.config.job-config }}
token: ${{ secrets.RHUB_TOKEN }}
extra-packages: any::rcmdcheck, margins=?ignore
extra-packages: any::rcmdcheck
- uses: r-hub/actions/run-check@v1
with:
job-config: ${{ matrix.config.job-config }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test-coverage.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::covr, margins=?ignore
extra-packages: any::covr
needs: coverage

- name: Test coverage
Expand Down
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: beeca
Title: Binary Endpoint Estimation with Covariate Adjustment
Version: 0.1.3.9000
Version: 0.2.0
Authors@R:
c(person(given = "Alex",
family = "Przybylski",
Expand Down Expand Up @@ -33,15 +33,15 @@ Maintainer: Alex Przybylski <[email protected]>
License: LGPL (>= 3)
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.2
Suggests:
knitr,
rmarkdown,
testthat (>= 3.0.0),
tidyr,
marginaleffects,
margins,
RobinCar (== 0.3.0)
RobinCar (>= 0.3.0)
Config/testthat/edition: 3
Depends:
R (>= 2.10)
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ importFrom(stats,predict)
importFrom(stats,terms)
importFrom(stats,var)
importFrom(stats,vcov)
importFrom(utils,combn)
importFrom(utils,packageVersion)
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# beeca (development version)
# beeca 0.2.0

- Extensions to allow for more than two treatment arms in the model fit.

# beeca 0.1.3

Expand Down
100 changes: 65 additions & 35 deletions R/apply_contrast.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@
#' The choice of contrast affects how treatment effects are calculated and
#' interpreted. Default is `diff`.
#'
#' @param reference a string indicating which treatment group should be considered as
#' the reference level. Accepted values are one of the levels in the treatment
#' variable. Default to the first level used in the `glm` object.
#' @param reference a string or list of strings indicating which treatment
#' group(s) to use as reference level for pairwise comparisons. Accepted values
#' must be a subset of the levels in the treatment variable. Default to the
#' first n-1 treatment levels used in the `glm` object.
#'
#' This parameter influences the calculation of treatment effects
#' relative to the chosen reference group.
Expand Down Expand Up @@ -68,18 +69,19 @@
#' fit3$marginal_est
#' fit3$marginal_se
#'
#' @importFrom utils combn
#' @export
#'
apply_contrast <- function(object, contrast = c("diff", "rr", "or", "logrr", "logor"), reference) {
# assert means are available in object
# Assert means are available in object
if (!"counterfactual.means" %in% names(object)) {
msg <- sprintf(
"Missing counterfactual means. First run `%1$s <- average_predictions(%1$s)`.",
deparse(quote(object))
)
stop(msg, call. = FALSE)
}
# assert varcov is available in object
# Assert varcov is available in object
if (!"robust_varcov" %in% names(object)) {
msg <- sprintf(
"Missing robust varcov. First run `%1$s <- get_varcov(%1$s, ...)`.",
Expand All @@ -92,53 +94,81 @@ apply_contrast <- function(object, contrast = c("diff", "rr", "or", "logrr", "lo
object <- .assert_sanitized(object, trt, warn = T)

data <- .get_data(object)
# assert non-missing reference

# Assert non-missing reference or set default
if (missing(reference)) {
reference <- object$xlevels[[trt]][1]
warning(sprintf("No reference argument was provided, using %s as the reference level", reference), call. = FALSE)
reference <- object$xlevels[[trt]][-nlevels(data[[trt]])]
warning(sprintf("No reference argument was provided, using {%s} as the reference level(s)", paste(reference, collapse = ", ")), call. = FALSE)
}

# Assert reference to be subset of the treatment levels
if (!all(reference %in% levels(data[[trt]]))) {
stop("Reference levels must be a subset of treatment levels : ", paste(levels(data[[trt]]), collapse = ", "), ".")
}

# assert reference to be one of the treatment levels
if (!reference %in% levels(data[[trt]])) {
stop("Reference must be one of : ", paste(levels(data[[trt]]), collapse = ", "), ".")
# Only accept at most nlevels(trt)-1 reference levels
if (length(reference) > (nlevels(data[[trt]])-1)) {
stop(sprintf("Too many reference levels provided. Expecting at most %s value(s) from the possible treatment levels: {%s}",
nlevels(data[[trt]]) - 1,
paste(levels(data[[trt]]), collapse = ", ")))
}

# match contrast argument and retrieve relevant functions
# Match contrast argument and retrieve relevant helper functions
contrast <- match.arg(contrast)
c_est <- get(contrast)
c_str <- get(paste0(contrast, "_str"))
c_grad <- get(paste0("grad_", contrast))

# set reference to 1 if it is the first level
reference_n <- which(levels(data[[trt]]) == reference)
cf_mean_ref <- object$counterfactual.means[reference_n]
cf_mean_inv <- object$counterfactual.means[3 - reference_n]

trt_ref <- reference
trt_inv <- levels(data[[trt]])[-reference_n]
# Extract counterfactual means
cf_means <- object$counterfactual.means

# apply contrast to point estimate
marginal_est <- c_est(cf_mean_inv, cf_mean_ref)
# Get all pairwise arm comparisons for contrasts
# Format according to specified reference levels
l <- levels(data[[trt]])
combos <- combn(l, 2, \(x) {
if (any(reference %in% x)) {
ref <- intersect(reference, x)[[1]]
comp <- x[x!=ref]
c(which(ref == l), which(comp==l))
} else {
c(NA, NA)
}
})
combos <- combos[, is.finite(colSums(combos)), drop=F]
if (length(reference) > 1) {
combos <- combos[,order(combos[1,])]
}
idxs <- matrix(F, nrow=length(l), ncol=length(l))
rownames(idxs) <- colnames(idxs) <- l
idxs[cbind(combos[2,], combos[1,])] <- T

object$marginal_est <- marginal_est
names(object$marginal_est) <- "marginal_est"
attr(object$marginal_est, "reference") <- trt_ref
attr(object$marginal_est, "contrast") <- paste0(contrast, ": ", c_str(trt_inv, trt_ref))

# apply contrast to varcov
gr <- c_grad(cf_mean_inv, cf_mean_ref)
# Get marginal estimates for the contrasts of interest, using correct ref levels
marginal_est <- outer(cf_means, cf_means, c_est)
marginal_est <- marginal_est[idxs]

robust_varcov <- object$robust_varcov
# Define contrast jacobian
gr <- matrix(0, nrow=ncol(combos), ncol=length(l))
# Apply the grad function to all combinations
contrast_values <- t(apply(combos, 2, function(idx) c_grad(cf_means[idx[2]], cf_means[idx[1]])))
gr[cbind(1:ncol(combos), combos[1,])] <- contrast_values[, 1]
gr[cbind(1:ncol(combos), combos[2,])] <- contrast_values[, 2]
# Get marginal standard error for contrasts of interest
marginal_se <- sqrt(diag(gr %*% object$robust_varcov %*% t(gr)))

# correct varcov order based on reference
if (reference_n == 2) diag(robust_varcov) <- rev(diag(robust_varcov))
# Add string description for each comparison
contrast_desc <- paste0(contrast, ": ",
apply(combos, 2, \(x) c_str(l[x[2]], l[x[1]])))
names(marginal_est) <- contrast_desc
names(marginal_se) <- contrast_desc

marginal_se <- sqrt(t(gr) %*% robust_varcov %*% gr)
object$marginal_est <- marginal_est
attr(object$marginal_est, "reference") <- reference
attr(object$marginal_est, "contrast") <- contrast_desc

object$marginal_se <- marginal_se[[1]]
names(object$marginal_se) <- "marginal_se"
attr(object$marginal_se, "reference") <- trt_ref
attr(object$marginal_se, "contrast") <- paste0(contrast, ": ", c_str(trt_inv, trt_ref))
object$marginal_se <- marginal_se
attr(object$marginal_se, "reference") <- reference
attr(object$marginal_se, "contrast") <- contrast_desc
attr(object$marginal_se, "type") <- attr(object$robust_varcov, "type")

return(object)
Expand Down
Loading

0 comments on commit 425a5ad

Please sign in to comment.