Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Anticipation #5

Merged
merged 16 commits into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 71 additions & 30 deletions R/aggregate_gt.R
Original file line number Diff line number Diff line change
@@ -1,52 +1,60 @@
aggregate_gt <- function(gt_result, cohort_sizes,
id_weights, id_cohorts,
result_type, balanced_event_time){
aggregate_gt <- function(gt_result, aux, p){

gt_att <- gt_result$att
gt_inf_func <- gt_result$inf_func
gt <- gt_result$gt

#release the stuff
id_cohorts <- aux$dt_inv[, G]

id_dt <- data.table(weight = id_weights/sum(id_weights), G = id_cohorts)
id_dt <- data.table(weight = aux$weights/sum(aux$weights), G = id_cohorts)
pg_dt <- id_dt[, .(pg = sum(weight)), by = "G"]
group_time <- gt |> merge(pg_dt, by = "G")
group_time <- gt_result$gt |> merge(pg_dt, by = "G")

setorder(group_time, time, G) #change the order to match the order in gtatt

gt_inf_func <- as.matrix(gt_inf_func)
gt_result$inf_func <- as.matrix(gt_result$inf_func)

if(result_type == "group_time"){
if(p$result_type == "group_time"){

#don't need to do anything
targets <- group_time[, unique(G*max(time)+time)]
inf_matrix <- gt_inf_func
agg_att <- as.vector(gt_att)
inf_matrix <- gt_result$inf_func
agg_att <- as.vector(gt_result$att)

} else {

agg_sch <- get_aggregate_scheme(group_time, result_type, id_weights, id_cohorts, balanced_event_time)
#get which gt(s) is a part of the aggregated param
agg_sch <- get_aggregate_scheme(group_time, p$result_type, aux$weights, id_cohorts, p$balanced_event_time)
targets <- agg_sch$targets
weights <- as.matrix(agg_sch$weights)

#aggregated att
agg_att <- weights %*% gt_att
agg_att <- agg_sch$agg_weights %*% gt_result$att

#get the influence from weight estimation
#this needs to be optimized!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
inf_weights <- sapply(asplit(weights, 1), function (x){
get_weight_influence(x, gt_att, id_weights, id_cohorts, group_time[, .(G, time)])
inf_weights <- sapply(asplit(agg_sch$agg_weights, 1), function (x){
get_weight_influence(x, gt_result$att, aux$weights, id_cohorts, group_time[, .(G, time)])
})

#aggregated influence function
inf_matrix <- (gt_inf_func %*% t(weights)) + inf_weights

inf_matrix <- (gt_result$inf_func %*% t(agg_sch$agg_weights)) + inf_weights
}
return(list(inf_matrix = inf_matrix, agg_att = agg_att, targets = targets))

#get se from influence function
agg_se <- get_se(inf_matrix, p$boot, p$biters, aux$cluster, p$clustervar)

# post process
result <- data.table(targets, agg_att, agg_se)
names(result) <- c("target", "att", "se")
result[,outcome := gt_result$outname]

return(result)
}

get_aggregate_scheme <- function(group_time, result_type, id_weights, id_cohorts, balanced_event_time){
get_aggregate_scheme <- function(group_time, result_type, weights, id_cohorts, balanced_event_time){

#browser()

weights <- data.table()
agg_weights <- data.table()
gt_count <- group_time[, .N]

bool_to_pn <- function(x){ifelse(x, 1, -1)}
Expand All @@ -67,7 +75,7 @@ get_aggregate_scheme <- function(group_time, result_type, id_weights, id_cohorts

#for balanced cohort composition in dynamic setting
#a cohort us only used if it is seen for all dynamic time
if(result_type == "dynamic" & !is.null(balanced_event_time)){
if(result_type == "dynamic" & !is.na(balanced_event_time)){

cohorts <- group_time[, .(max_et = max(time-G),
min_et = min(time-G)), by = "G"]
Expand All @@ -89,18 +97,18 @@ get_aggregate_scheme <- function(group_time, result_type, id_weights, id_cohorts

group_time[, targeted := NULL]

weights <- rbind(weights, target_weights)
agg_weights <- rbind(agg_weights, target_weights)
}

return(list(weights = weights, #a matrix of each target and gt's weight in it
return(list(agg_weights = as.matrix(agg_weights), #a matrix of each target and gt's weight in it
targets = targets))
}

get_weight_influence <- function(agg_weights, gt_att, id_weights, id_cohorts, group) {
get_weight_influence <- function(agg_weights, gt_att, weights, id_cohorts, group) {

keepers <- which(agg_weights > 0)

id_dt <- data.table(weight = id_weights/sum(id_weights), G = id_cohorts)
id_dt <- data.table(weight = weights/sum(weights), G = id_cohorts)
pg_dt <- id_dt[, .(pg = sum(weight)), by = "G"]
group <- group |> merge(pg_dt, by = "G")

Expand All @@ -110,16 +118,49 @@ get_weight_influence <- function(agg_weights, gt_att, id_weights, id_cohorts, gr

# effect of estimating weights in the numerator
if1 <- sapply(keepers, function(k) {
(id_weights*BMisc::TorF(id_cohorts == group[k,G]) - group[k,pg]) /
(weights*BMisc::TorF(id_cohorts == group[k,G]) - group[k,pg]) /
sum(group[keepers,pg])
})
# effect of estimating weights in the denominator
if2 <- base::rowSums(sapply(keepers, function(k) {
id_weights*BMisc::TorF(id_cohorts == group[k,G]) - group[k,pg]
weights*BMisc::TorF(id_cohorts == group[k,G]) - group[k,pg]
})) %*%
t(group[keepers,pg]/(sum(group[keepers,pg])^2))
# return the influence function for the weights
inf_weight <- (if1 - if2) %*% as.vector(gt_att[keepers])
inf_weight[abs(inf_weight) < sqrt(.Machine$double.eps)*10] <- 0 #fill zero
return(inf_weight)
}
}

get_se <- function(inf_matrix, boot, biters, cluster, clustervar) {

if(boot){

top_quant <- 0.75
bot_quant <- 0.25
if(!allNA(clustervar)){
#take average within the cluster
cluster_n <- stats::aggregate(cluster, by=list(cluster), length)[,2]
inf_matrix <- fsum(inf_matrix, cluster) / cluster_n #the mean without 0 for each cluster of each setting
}

boot_results <- BMisc::multiplier_bootstrap(inf_matrix, biters = biters) %>% as.data.table()

boot_top <- boot_results[, lapply(.SD, function(x) stats::quantile(x, top_quant, type=1, na.rm = TRUE))]
boot_bot <- boot_results[, lapply(.SD, function(x) stats::quantile(x, bot_quant, type=1, na.rm = TRUE))]

dt_se <- rbind(boot_bot, boot_top) %>% transpose()
names(dt_se) <- c("boot_bot", "boot_top")

se <- dt_se[,(boot_top-boot_bot)/(qnorm(top_quant) - qnorm(bot_quant))]
se[se < sqrt(.Machine$double.eps)*10] <- NA

} else {

inf_matrix <- inf_matrix |> as.data.table()
se <- inf_matrix[, lapply(.SD, function(x) sqrt(sum(x^2, na.rm = TRUE)/length(x)^2))] %>% as.vector() #should maybe use n-1 but did use n

}
return(unlist(se))
}

37 changes: 22 additions & 15 deletions R/estimate_did.R
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
estimate_did <- function(dt_did, covnames, control_type,
estimate_did <- function(dt_did, covvars, control_type,
last_coef = NULL, cache_ps_fit, cache_hess){

# preprocess --------

oldn <- dt_did[, .N]
data_pos <- which(dt_did[, !is.na(D)])
dt_did <- dt_did[data_pos]
n <- dt_did[, .N]
if(is.null(covnames)){
covvars <- NULL

if(is.matrix(covvars)){
ipw <- control_type %in% c("ipw", "dr")
or <- control_type %in% c("reg", "dr")
covvars <- covvars[data_pos,]
} else {
covvars <- as.matrix(dt_did[,.SD, .SDcols = covnames])
ipw <- FALSE
or <- FALSE
}

ipw <- control_type %in% c("ipw", "dr") & !is.null(covvars)
or <- control_type %in% c("reg", "dr") & !is.null(covvars) #OR is REG


# ipw --------

if(ipw){
Expand All @@ -29,19 +29,25 @@ estimate_did <- function(dt_did, covnames, control_type,
intercept = FALSE))
class(prop_score_est) <- "glm" #trick the vcov function to think that this is a glm object to dispatch the write method
#const is implicitly put into the ipw formula, need to incorporate it manually
hess <- stats::vcov(prop_score_est) * n #for the influence function

logit_coef <- prop_score_est$coefficients

if(anyNA(logit_coef)){
warning("some propensity score estimation resulted in NA coefficients, likely cause by perfect colinearity")
}

logit_coef[is.na(logit_coef)|abs(logit_coef) > 1e10] <- 0 #put extreme value and na to 0
prop_score_fit <- fitted(prop_score_est)
if(max(prop_score_fit) >= 1){warning(paste0("support overlap condition violated for some group_time"))}
prop_score_fit <- pmin(1-1e-16, prop_score_fit) #for the ipw

hess <- stats::vcov(prop_score_est) * n #for the influence function
hess[is.na(hess)|abs(hess) > 1e10] <- 0

} else { #when using multiple outcome, ipw cache can be reused
hess <- cache_hess
prop_score_fit <- cache_ps_fit
logit_coef <- NULL #won't be needing the approximate cache
logit_coef <- NA #won't be needing the approximate cache
}

#get the results into the main did dt
Expand All @@ -52,8 +58,8 @@ estimate_did <- function(dt_did, covnames, control_type,
} else {

prop_score_fit <- rep(1,n)
logit_coef <- NULL
hess <- NULL
logit_coef <- NA
hess <- NA

dt_did[, treat_ipw_weight := weights*D]
dt_did[, cont_ipw_weight := weights*(1-D)]
Expand Down Expand Up @@ -104,6 +110,7 @@ estimate_did <- function(dt_did, covnames, control_type,
M2 <- colMeans(dt_did[, cont_ipw_weight*(delta_y-weighted_cont_delta-or_delta)] * covvars)

score_ps <- dt_did[, weights*(D-ps)] * covvars

asym_linear_ps <- score_ps %*% hess

#ipw for control
Expand Down Expand Up @@ -140,6 +147,7 @@ estimate_did <- function(dt_did, covnames, control_type,


#get overall influence function
#if(dt_did[, mean(cont_ipw_weight)] < 1e-10){warning("little/no overlap in covariates between control and treat group, estimates are unstable.")}
inf_cont <- (inf_cont_did+inf_cont_ipw+inf_cont_or)/dt_did[, mean(cont_ipw_weight)]
inf_treat <- (inf_treat_did+inf_treat_or)/dt_did[,mean(treat_ipw_weight)]
inf_func_no_na <- inf_treat - inf_cont
Expand All @@ -149,7 +157,6 @@ estimate_did <- function(dt_did, covnames, control_type,
inf_func_no_na <- inf_func_no_na * oldn / n #adjust the value such that mean over the whole id size give the right result
inf_func[data_pos] <- inf_func_no_na


return(list(att = att, inf_func = inf_func, logit_coef = logit_coef, #for next gt
cache_ps_fit = prop_score_fit, cache_hess = hess)) #for next outcome
}
27 changes: 13 additions & 14 deletions R/estimate_did_rc.R
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
estimate_did_rc <- function(dt_did, covnames, control_type,
estimate_did_rc <- function(dt_did, covvars, control_type,
last_coef = NULL, cache_ps_fit, cache_hess){

#TODO: skip if not enough valid data

# preprocess --------

oldn <- dt_did[, .N]

#separate the dataset into pre and post

oldn <- dt_did[, .N]
data_pos <- which(dt_did[, !is.na(D)])
dt_did <- dt_did[data_pos]
n <- dt_did[, .N]

#separate the dataset into pre and post
dt_did[, inpre := as.numeric(!is.na(pre.y))]
dt_did[, inpost := as.numeric(!is.na(post.y))]
n_pre <- dt_did[, sum(!is.na(pre.y))]
Expand All @@ -21,15 +20,15 @@ estimate_did_rc <- function(dt_did, covnames, control_type,
sum_weight_pre <- dt_did[, sum(inpre*weights)]
sum_weight_post <- dt_did[, sum(inpost*weights)]

if(is.null(covnames)){
covvars <- NULL
if(is.matrix(covvars)){
ipw <- control_type %in% c("ipw", "dr")
or <- control_type %in% c("reg", "dr")
covvars <- covvars[data_pos,]
} else {
covvars <- as.matrix(dt_did[,.SD, .SDcols = covnames])
ipw <- FALSE
or <- FALSE
}

ipw <- control_type %in% c("ipw", "dr") & !is.null(covvars)
or <- control_type %in% c("reg", "dr") & !is.null(covvars) #OR is REG


# ipw --------

if(ipw){
Expand Down Expand Up @@ -62,8 +61,8 @@ estimate_did_rc <- function(dt_did, covnames, control_type,
} else {

prop_score_fit <- rep(1,n)
logit_coef <- NULL
hess <- NULL
logit_coef <- NA
hess <- NA
dt_did[, treat_ipw_weight := weights*D]
dt_did[, cont_ipw_weight := weights*(1-D)]

Expand Down Expand Up @@ -191,8 +190,8 @@ estimate_did_rc <- function(dt_did, covnames, control_type,
inf_treat_post <- inf_treat_did_post+inf_treat_or_post
inf_cont_pre <- inf_cont_did_pre+inf_cont_ipw_pre+inf_cont_or_pre
inf_treat_pre <- inf_treat_did_pre+inf_treat_or_pre

#post process

inf_func_no_na_post <- (inf_treat_post - inf_cont_post) * oldn / n_post #adjust the value such that mean over the whole id size give the right result
inf_func_no_na_post[is.na(inf_func_no_na_post)] <- 0 #fill 0 for NA part (no influce if not in this gt)

Expand Down
Loading
Loading