Skip to content

Commit

Permalink
All methods accessible via causens API (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
larryshamalama authored Jul 2, 2024
1 parent f36a45c commit 53dfdf9
Show file tree
Hide file tree
Showing 13 changed files with 111 additions and 48 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ Description: This package provides functionality to perform causal
License: MIT
URL: https://kuan-liu-lab.github.io/causens/, https://github.com/Kuan-Liu-Lab/causens
Collate:
"bayesian_causens.R"
"causens_bayesian.R"
"causens_sf.R"
"causens.R"
"plot.R"
"sensitivity_function.R"
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Generated by roxygen2: do not edit by hand

export(causens)
export(causens_sf)
export(gData_U_binary_Y_binary)
export(gData_U_binary_Y_cont)
export(gData_U_cont_Y_binary)
Expand Down
38 changes: 12 additions & 26 deletions R/causens.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@
#' to the sensitivity function `sf`.
#'
#' @param trt_model The treatment model object as a formula or fitted glm.
#' @param data A data frame containing the variables of interest.
#' @param outcome The name of the outcome variable.
#' @param method The method to use for sensitivity analysis. Currently, only
#' "Li" is supported.
#' @param method The method to use for sensitivity analysis. Currently, "Li" and
#' "Bayesian" are supported.
#' @param data A data frame containing the exposure, outcome, and confounder variables.
#' @param ... Additional arguments to be passed to the sensitivity function.
#'
#' @return A point estimate of the corrected ATE.
#'
#' @export
causens <- function(trt_model, data, outcome, method, ...) {
y <- data[[outcome]]
causens <- function(trt_model, outcome, method, data, ...) {

if (inherits(trt_model, "formula")) {
fitted_model <- glm(trt_model, data = data, family = binomial)
Expand All @@ -30,32 +29,19 @@ causens <- function(trt_model, data, outcome, method, ...) {
stop("Treatment model must be a formula or a glm object.")
}

z_index <- attr(terms(trt_formula), "response")
z <- data[[all.vars(trt_formula)[[z_index]]]]
trt_index <- attr(terms(trt_formula), "response")
trt_var_name <- all.vars(trt_formula)[[trt_index]]

e <- predict(fitted_model, type = "response")
method <- tolower(method) # case-insensitive

if (method == "Li") {
c1 <- sf(z = 1, e = e, ...)
c0 <- sf(z = 0, e = 1 - e, ...)
if (method == "sf" || method == "li") {
estimated_ate <- causens_sf(fitted_model, trt_var_name, outcome, data, ...)
} else if (method == "bayesian") {
confounder_names <- attr(terms(trt_model), "term.labels")
estimated_ate <- bayesian_causens(trt_var_name, outcome, confounder_names, data, ...)
} else {
stop("Method not recognized or not implemented yet.")
}

# Calculate the Average Treatment Effect
weights <- 1 / ifelse(z == 1, e, 1 - e)

if (all(y %in% c(0, 1))) {
Y_sf <- y * (abs(1 - z - e) + exp((-1)**(z == 1) * ifelse(z, c1, c0) * abs(z - e)))
} else {
Y_sf <- y + (-1)**(z == 1) * abs(z - e) * ifelse(z, c1, c0)
}

# Potential outcomes corrected w.r.t. sensitivity function
Y1_sf <- sum((Y_sf * weights)[z == 1]) / sum(weights[z == 1])
Y0_sf <- sum((Y_sf * weights)[z == 0]) / sum(weights[z == 0])

estimated_ate <- Y1_sf - Y0_sf

return(estimated_ate)
}
4 changes: 2 additions & 2 deletions R/bayesian_causens.R → R/causens_bayesian.R
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
#' @title Bayesian parametric sensitivity analysis for causal inference
#' @description This function runs a Bayesian sensitivity analysis for causal
#' inference using JAGS or Stan as a backend.
#' @param data A data frame containing the exposure, outcome, and confounders.
#' @param exposure The name of the exposure variable in the data frame.
#' @param outcome The name of the outcome variable in the data frame.
#' @param confounders The names of the confounder variables in the data frame.
#' @param data A data frame containing the exposure, outcome, and confounder variables.
#' @param backend The backend to use for the sensitivity analysis. Currently
#' only "jags" is supported.
#' @param output_trace Whether to output the full trace of the MCMC sampler.
#' @param ... Additional arguments to be passed to the backend.
#' @return A list of posterior samples for the causal effect of the exposure
#' variable on the outcome, as well as the confounder-adjusted causal effect.
bayesian_causens <- function(data, exposure, outcome, confounders, backend = "jags", output_trace = FALSE, ...) {
bayesian_causens <- function(exposure, outcome, confounders, data, backend = "jags", output_trace = FALSE, ...) {
if (backend == "rjags" || backend == "jags") {
require(rjags)
} else if (backend == "stan" || backend == "rstan") {
Expand Down
39 changes: 39 additions & 0 deletions R/causens_sf.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#' @title Sensitivity-Function Estimation of ATE Subject to Unmeasured
#' Confounding
#'
#' @description This function provides an estimate of the Average Treatment
#' Effect (ATE) by weighting outcomes with the inverse of the propensity
#' scores predicted by `fitted_model`, after correcting each outcome with the
#' values returned by the sensitivity function `sf`.
#'
#' @param fitted_model The treatment model object as a glm.
#' @param exposure The name of the exposure variable.
#' @param outcome The name of the outcome variable.
#' @param data A data frame containing the exposure, outcome, and confounder
#' variables.
#' @param ... Additional arguments to be passed to the sensitivity function.
#'
#' @return A point estimate of the corrected ATE.
#' @export
causens_sf <- function(fitted_model, exposure, outcome, data, ...) {
  y <- data[[outcome]]
  z <- data[[exposure]]

  # `[[` returns NULL for an unknown column name; fail early with a clear
  # message instead of letting NULL propagate into the estimate.
  if (is.null(y)) {
    stop(
      "Outcome variable '", outcome, "' not found in `data`.",
      call. = FALSE
    )
  }
  if (is.null(z)) {
    stop(
      "Exposure variable '", exposure, "' not found in `data`.",
      call. = FALSE
    )
  }

  # Propensity scores come from the model's own fitting frame (no `newdata`).
  e <- predict(fitted_model, type = "response")

  # Guard against silent recycling when the model was fitted on a different
  # number of rows than `data` provides.
  if (length(e) != length(z)) {
    stop(
      "`fitted_model` yields ", length(e), " propensity scores but `data` has ",
      length(z), " rows; fit the treatment model on the same data.",
      call. = FALSE
    )
  }

  # Sensitivity function evaluated for the treated and untreated arms.
  c1 <- sf(z = 1, e = e, ...)
  c0 <- sf(z = 0, e = 1 - e, ...)

  # Inverse probability of the treatment actually received.
  weights <- 1 / ifelse(z == 1, e, 1 - e)

  # Sign of the correction: -1 for treated units, +1 for untreated units.
  direction <- (-1)^(z == 1)
  c_z <- ifelse(z, c1, c0)

  if (all(y %in% c(0, 1))) {
    # Binary outcome: multiplicative correction.
    Y_sf <- y * (abs(1 - z - e) + exp(direction * c_z * abs(z - e)))
  } else {
    # Otherwise: additive correction.
    Y_sf <- y + direction * abs(z - e) * c_z
  }

  # Potential outcomes corrected w.r.t. sensitivity function
  Y1_sf <- sum((Y_sf * weights)[z == 1]) / sum(weights[z == 1])
  Y0_sf <- sum((Y_sf * weights)[z == 0]) / sum(weights[z == 0])

  Y1_sf - Y0_sf
}
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ _Why is it that more shark attacks occur when more ice cream is sold? The answer

## Overview

causens is an R package that will allow to perform various sensitivity
`causens` is an R package that provides various sensitivity
analysis methods to adjust for unmeasured confounding within the context of
causal inference.
causal inference. Currently, we provide the following methods:

- Sensitivity function + propensity score ([Li et al. (2011)](https://pubmed.ncbi.nlm.nih.gov/21659349/), [Brumback et al. (2004)](https://onlinelibrary.wiley.com/doi/10.1002/sim.1657))
- Bayesian parametric sensitivity analysis ([McCandless and Gustafson (2017)](https://onlinelibrary.wiley.com/doi/abs/10.1002/sim.7298))

## Installation

Expand Down
8 changes: 4 additions & 4 deletions man/bayesian_causens.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions man/causens.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

26 changes: 26 additions & 0 deletions man/causens_sf.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/create_jags_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test_bayesian_sa.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ run_simulation <- function(seed) {
seed = seed, treatment_effects = trt_effect
)

return(bayesian_causens(data, "Z", "Y", c("X.1", "X.2", "X.3"), ))
return(causens(Z ~ X.1 + X.2 + X.3, "Y", data, method = "Bayesian"))
}

simulated_ate <- c()
Expand Down
17 changes: 12 additions & 5 deletions tests/testthat/test_causens.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ for (params in parameters) {

trt_model <- Z ~ X.1 + X.2 + X.3

return(causens(trt_model, data, "Y", method = "Li", c1 = c1, c0 = c0))
return(causens(trt_model, "Y", method = "Li", data = data, c1 = c1, c0 = c0))
}

# Because alpha_uz > 0 and beta_uy > 0, treated individuals are more likely to
Expand All @@ -61,8 +61,15 @@ for (params in parameters) {
# Testing `trt_model` input types
trt_model <- Z ~ X.1 + X.2 + X.3

est_ate_1 <- causens(trt_model, data, "Y", method = "Li", c1 = 0.25, c0 = 0.25)
est_ate_2 <- causens(glm(trt_model, data = data, family = binomial()), data, "Y", method = "Li", c1 = 0.25, c0 = 0.25)
est_ate_1 <- causens(trt_model, "Y", method = "Li", data = data, c1 = 0.25, c0 = 0.25)
est_ate_2 <- causens(
glm(trt_model, data = data, family = binomial()),
"Y",
method = "Li",
data = data,
c1 = 0.25,
c0 = 0.25
)

test_that("trt_model can be a formula or fitted glm model", {
expect_equal(est_ate_1, est_ate_2)
Expand All @@ -74,14 +81,14 @@ test_that("causens throws an error if `trt_model` input is invalid", {
trt_model <- "Z ~ X.1 + X.2 + X.3"

expect_error(
object = causens(trt_model, data, "Y", method = "Li", c1 = 0.25, c0 = 0.25),
object = causens(trt_model, "Y", method = "Li", data = data, c1 = 0.25, c0 = 0.25),
regexp = "Treatment model must be a formula or a glm object."
)
})

test_that("causens throws an error if `method` input is invalid", {
expect_error(
object = causens(Z ~ 1, data, "Y", method = "???", c1 = 0.25, c0 = 0.25),
object = causens(Z ~ 1, "Y", method = "???", data = data, c1 = 0.25, c0 = 0.25),
regexp = "Method not recognized or not implemented yet."
)
})
2 changes: 1 addition & 1 deletion tests/testthat/test_causens_binary.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ for (params in parameters) {

trt_model <- Z ~ X.1 + X.2 + X.3

return(causens(trt_model, data, "Y", method = "Li", c1 = 0, c0 = 0))
return(causens(trt_model, "Y", data, method = "Li", c1 = 0, c0 = 0))
}

simulated_ates <- unlist(lapply(1:1000, run_simulation))
Expand Down

0 comments on commit 53dfdf9

Please sign in to comment.