Skip to content

Commit

Permalink
0.9.3.3: validate interface improve, anticipation, group_time target
Browse files Browse the repository at this point in the history
  • Loading branch information
TsaiLintung committed Jul 22, 2024
1 parent 51f7f7f commit 36a8de2
Show file tree
Hide file tree
Showing 8 changed files with 129 additions and 137 deletions.
5 changes: 2 additions & 3 deletions R/aggregate_gt.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ aggregate_gt <- function(gt_result, aux, p){
group_time <- gt_result$gt |> merge(pg_dt, by = "G")

setorder(group_time, time, G) #change the order to match the order in gtatt

gt_result$inf_func <- as.matrix(gt_result$inf_func)

if(p$result_type == "group_time"){

#don't need to do anything
targets <- group_time[, unique(G*max(time)+time)]
targets <- group_time[, paste0(G, ".", time)]
inf_matrix <- gt_result$inf_func
agg_att <- as.vector(gt_result$att)

Expand Down
5 changes: 4 additions & 1 deletion R/estimate_gtatt.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ estimate_gtatt <- function(aux, p) {
#assign cache for next outcome
if(is.null(cache_ps_fit_list[[gt_name]])){cache_ps_fit_list[[gt_name]] <- result$cache_ps_fit}
if(is.null(cache_hess_list[[gt_name]])){cache_hess_list[[gt_name]] <- result$cache_hess}

rm(result)

}
Expand All @@ -67,6 +68,7 @@ estimate_gtatt <- function(aux, p) {
}

return(outcome_result_list)

}

get_did_setup <- function(g, t, base_period, aux, p){
Expand Down Expand Up @@ -94,7 +96,8 @@ get_did_setup <- function(g, t, base_period, aux, p){
if(t == base_period | #no treatment effect for the base period
base_period < min(aux$time_periods) | #no treatment effect for the first period, since base period is not observed
g >= max_control_cohort | #no treatment effect for never treated or the last treated cohort (for not yet notyet)
t >= max_control_cohort){ #no control available if the last cohort is treated too
t >= max_control_cohort | #no control available if the last cohort is treated too
min_control_cohort > max_control_cohort){ #no control avalilble, most likely due to anticipation
return(NULL)
} else {
#select the control and treated cohorts
Expand Down
90 changes: 14 additions & 76 deletions R/fastdid.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,86 +69,25 @@ fastdid <- function(data,
copy = TRUE, validate = TRUE,
max_control_cohort_diff = Inf, anticipation = 0, min_control_cohort_diff = -Inf, base_period = "universal"
){

# validate arguments --------------------------------------------------------

# validation --------------------------------------------------------

if(!is.data.table(data)){
warning("coercing input into a data.table.")
data <- as.data.table(data)
}
if(copy){dt <- copy(data)} else {dt <- data}

# validate arguments
p <- as.list(environment()) #collect everything besides data
p$data <- NULL
validate_argument(p, names(data))

dt_names <- names(dt)
name_message <- "__ARG__ must be a character scalar and a name of a column from the dataset."
check_set_arg(timevar, unitvar, cohortvar, "match", .choices = dt_names, .message = name_message)

covariate_message <- "__ARG__ must be NA or a character vector which are all names of columns from the dataset."
check_set_arg(varycovariatesvar, covariatesvar, outcomevar,
"NA | multi match", .choices = dt_names, .message = covariate_message)

checkvar_message <- "__ARG__ must be NA or a character scalar if a name of columns from the dataset."
check_set_arg(weightvar, clustervar, filtervar,
"NA | match", .choices = dt_names, .message = checkvar_message)

check_set_arg(control_option, "match", .choices = c("both", "never", "notyet")) #kinda bad names since did's notyet include both notyet and never
check_set_arg(control_type, "match", .choices = c("ipw", "reg", "dr"))
check_set_arg(base_period, "match", .choices = c("varying", "universal"))
check_arg(copy, validate, boot, allow_unbalance_panel, "scalar logical")
check_arg(max_control_cohort_diff, min_control_cohort_diff, anticipation, "scalar numeric")

if(!is.na(balanced_event_time)){
if(result_type != "dynamic"){stop("balanced_event_time is only meaningful with result_type == 'dynamic'")}
check_arg(balanced_event_time, "numeric scalar")
}
if(allow_unbalance_panel == TRUE & control_type %in% c("dr", "reg")){
stop("fastdid currently only supprts ipw when allowing for unbalanced panels.")
}
if(allow_unbalance_panel == TRUE & !allNA(varycovariatesvar)){
stop("fastdid currently only supprts time varying covariates when allowing for unbalanced panels.")
}
if(any(covariatesvar %in% varycovariatesvar) & !allNA(varycovariatesvar) & !allNA(covariatesvar)){
stop("time-varying var and invariant var have overlaps.")
}
if(!boot & !allNA(clustervar)){
stop("clustering only available with bootstrap")
}

# coerce non-sensible option
if(!is.na(clustervar) && unitvar == clustervar){clustervar <- NA} #cluster on id anyway, would cause error otherwise
if((!is.infinite(max_control_cohort_diff) | !is.infinite(min_control_cohort_diff)) & control_option == "never"){
warning("control_cohort_diff can only be used with not yet")
p$control_option <- "notyet"
}

p <- list(timevar = timevar,
cohortvar = cohortvar,
unitvar = unitvar,
outcomevar = outcomevar,
weightvar = weightvar,
clustervar = clustervar,
filtervar = filtervar,
covariatesvar = covariatesvar,
varycovariatesvar = varycovariatesvar,
control_option = control_option,
result_type = result_type,
balanced_event_time = balanced_event_time,
control_type = control_type,
allow_unbalance_panel = allow_unbalance_panel,
boot = boot,
biters = biters,
max_control_cohort_diff = max_control_cohort_diff,
min_control_cohort_diff = min_control_cohort_diff,
anticipation = anticipation,
base_period = base_period)


# validate data -----------------------------------------------------

# validate data
setnames(dt, c(timevar, cohortvar, unitvar), c("time", "G", "unit"))

if(validate){
varnames <- c("time", "G", "unit",outcomevar,weightvar,clustervar,covariatesvar,varycovariatesvar,filtervar)
dt <- validate_did(dt, varnames, p)
varnames <- c("time", "G", "unit", outcomevar, weightvar, clustervar, covariatesvar, varycovariatesvar, filtervar)
dt <- validate_dt(dt, varnames, p)
}

# preprocess -----------------------------------------------------------
Expand Down Expand Up @@ -183,6 +122,8 @@ fastdid <- function(data,

# small steps ----------------------------------------------------------------------



coerce_dt <- function(dt, p){

#change to int before sorting
Expand Down Expand Up @@ -225,7 +166,6 @@ coerce_dt <- function(dt, p){
return(list(dt = dt,
time_change = list(time_step = time_step,
max_time = max(time_periods),
last_treated_cohort = ifelse(p$control_option == "notyet", dt[!is.infinite(G),max(G)], dt[,max(G)]),
time_offset = time_offset)))

}
Expand Down Expand Up @@ -353,10 +293,8 @@ convert_targets <- function(results, result_type, t){

} else if (result_type == "group_time"){

max_avail_time <- min(t$max_time, t$last_treated_cohort-1)

results[, cohort := floor((target-1)/max_avail_time)]
results[, time := (target-cohort*max_avail_time)]
results[, cohort := as.numeric(str_split_i(target, "\\.", 1))]
results[, time := as.numeric(str_split_i(target, "\\.", 2))]

#recover the time
results[, cohort := recover_time(cohort, t$time_offset, t$time_step)]
Expand Down
5 changes: 4 additions & 1 deletion R/global.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,8 @@ utils::globalVariables(c('.','agg_weight','att','att_cont','att_treat','attgt','
'post.y','pre.y','ps','s','se','target','tau','time_fe',
'treat_ipw_weight','treat_latent','type','unit','unit_fe','weight','x','x2',
'x_trend','y','y0','y1','y2', 'time', 'weights', 'outcome', "G", "D", 'xvar',
'V1','att_cont_post','att_cont_pre','att_treat_post','att_treat_pre','inpost','inpre','max_et','min_et','new_unit','or_delta','or_delta_post','or_delta_pre','targeted','used'))
'V1','att_cont_post','att_cont_pre','att_treat_post','att_treat_pre','inpost','inpre','max_et','min_et','new_unit','or_delta','or_delta_post','or_delta_pre','targeted','used',
"timevar", "cohortvar", "unitvar", "outcomevar", "control_option", "result_type", "balanced_event_time", "control_type",
"allow_unbalance_panel", "boot", "biters", "weightvar", "clustervar", "covariatesvar", "varycovariatesvar", "filtervar",
"copy", "validate", "max_control_cohort_diff", "anticipation", "min_control_cohort_diff", "base_period"))

103 changes: 103 additions & 0 deletions R/validate.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Validate the user-facing arguments collected in `p`.
#
# p        : named list of arguments; its elements are released as local
#            variables so the checkers can be called with bare names.
# dt_names : character vector of the column names of the dataset, used as the
#            set of valid choices for every `*var` argument.
#
# Stops with an informative error when an argument (or a combination of
# arguments) is invalid. Returns (invisibly) `p` with the non-sensible options
# coerced: `clustervar` is set to NA when it equals `unitvar`, and
# `control_option` is switched to "notyet" when a control cohort difference is
# requested together with "never". Callers that want these coercions to take
# effect must re-assign the result, e.g. `p <- validate_argument(p, names(dt))`.
validate_argument <- function(p, dt_names){

  # release p: the checks below refer to the arguments by bare name
  list2env(p, envir = environment())

  name_message <- "__ARG__ must be a character scalar and a name of a column from the dataset."
  check_set_arg(timevar, unitvar, cohortvar, "match", .choices = dt_names, .message = name_message)

  covariate_message <- "__ARG__ must be NA or a character vector which are all names of columns from the dataset."
  check_set_arg(varycovariatesvar, covariatesvar, outcomevar,
                "NA | multi match", .choices = dt_names, .message = covariate_message)

  checkvar_message <- "__ARG__ must be NA or a character scalar if a name of columns from the dataset."
  check_set_arg(weightvar, clustervar, filtervar,
                "NA | match", .choices = dt_names, .message = checkvar_message)

  check_set_arg(control_option, "match", .choices = c("both", "never", "notyet")) #kinda bad names since did's notyet include both notyet and never
  check_set_arg(control_type, "match", .choices = c("ipw", "reg", "dr"))
  check_set_arg(base_period, "match", .choices = c("varying", "universal"))
  check_arg(copy, validate, boot, allow_unbalance_panel, "scalar logical")
  check_arg(max_control_cohort_diff, min_control_cohort_diff, anticipation, "scalar numeric")

  # combinations of arguments that are not supported
  if(!is.na(balanced_event_time)){
    if(result_type != "dynamic"){stop("balanced_event_time is only meaningful with result_type == 'dynamic'")}
    check_arg(balanced_event_time, "numeric scalar")
  }
  if(allow_unbalance_panel && control_type %in% c("dr", "reg")){
    stop("fastdid currently only supprts ipw when allowing for unbalanced panels.")
  }
  if(allow_unbalance_panel && !allNA(varycovariatesvar)){
    # the condition rejects time-varying covariates, so the message must say so
    stop("fastdid currently does not support time-varying covariates when allowing for unbalanced panels.")
  }
  if(any(covariatesvar %in% varycovariatesvar) && !allNA(varycovariatesvar) && !allNA(covariatesvar)){
    stop("time-varying var and invariant var have overlaps.")
  }
  if(!boot && !allNA(clustervar)){
    stop("clustering only available with bootstrap")
  }

  # coerce non-sensible option
  # the coercions are written to p (not to the released local copies) so that
  # they survive the function call through the returned list
  if(!is.na(clustervar) && unitvar == clustervar){
    p$clustervar <- NA #cluster on id anyway, would cause error otherwise
  }
  if((!is.infinite(max_control_cohort_diff) || !is.infinite(min_control_cohort_diff)) && control_option == "never"){
    warning("control_cohort_diff can only be used with not yet")
    p$control_option <- "notyet"
  }

  invisible(p)
}

# Validate (and clean) the dataset after the key columns have been renamed to
# "time", "G" and "unit".
#
# dt       : data.table with columns `time`, `G`, `unit` plus the columns
#            named in `varnames`.
# varnames : character vector of the columns used in the estimation; entries
#            that are NA are skipped.
# p        : list of arguments; reads `balanced_event_time`, `filtervar`,
#            `covariatesvar`, `varycovariatesvar` and `allow_unbalance_panel`.
#
# Stops when `balanced_event_time` exceeds the largest event time, when the
# filter column is not logical, when a covariate has no variation, or when a
# unit is observed more than once in a period. Warns and drops rows with
# missing values, and (unless unbalanced panels are allowed) units that are
# not observed in every period. Returns the cleaned data.table.
validate_dt <- function(dt, varnames, p){

  # number of periods in the raw data; a missing `time` is not a period, so it
  # must not be counted (rows with missing values are dropped below anyway)
  raw_time_size <- dt[, uniqueN(time, na.rm = TRUE)]

  if(!is.na(p$balanced_event_time)){
    # this runs before missing values are removed, so ignore NA event times
    # instead of letting the comparison evaluate to NA
    if(p$balanced_event_time > dt[, max(time - G, na.rm = TRUE)]){
      stop("balanced_event_time is larger than the max event time in the data")
    }
  }

  if(!is.na(p$filtervar) && !is.logical(dt[[p$filtervar]])){
    stop("filter var needs to be a logical column")
  }

  #doesn't allow missing value for now
  for(col in varnames){
    if(is.na(col)){next}
    na_obs <- whichNA(dt[[col]])
    if(length(na_obs) != 0){
      warning("missing values detected in ", col, ", removing ", length(na_obs), " observation.")
      dt <- dt[!na_obs]
    }
  }

  # more (unit, covariates) combinations than units means a covariate changes
  # within a unit; compare against the units still present after NA removal
  if(!allNA(p$covariatesvar) && uniqueN(dt, by = c("unit", p$covariatesvar)) > dt[, uniqueN(unit)]){
    warning("some covariates is time-varying, fastdid only use the first observation for covariates.")
  }

  if(!allNA(p$covariatesvar) || !allNA(p$varycovariatesvar)){
    for(cov in c(p$covariatesvar, p$varycovariatesvar)){
      if(is.na(cov)){next}
      #check covaraites is not constant (using the first observation of each unit)
      if(fnunique(dt[, get(cov)[1], by = "unit"][, V1]) == 1){
        stop(cov, " have no variation")
      }
    }
  }

  #check balanced panel
  #check if any is dup
  if(anyDuplicated(dt[, .(unit, time)]) > 0){
    dup_id <- dt[duplicated(dt[, .(unit, time)]), unique(unit)]
    stop(length(dup_id), " units is observed more than once in a period.")
  }

  #check if any is missing
  if(!p$allow_unbalance_panel){
    unit_count <- dt[, .(count = .N), by = unit]
    mis_unit <- unit_count[count < raw_time_size]
    if(nrow(mis_unit) > 0){
      warning(nrow(mis_unit), " units is missing in some periods, enforcing balanced panel by dropping them")
      dt <- dt[!unit %in% mis_unit[, unit]]
    }
  }

  return(dt)
}
54 changes: 0 additions & 54 deletions R/validate_did.R

This file was deleted.

3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,9 @@ Since **fastdid** is not on CRAN yet, it needs to be converted to R scripts to b
- add time-varying control ([reference](https://arxiv.org/abs/2202.02903))
- add filtervar

0.9.3.1 (2024/5/24): fix the bug with `unitvar == clustervar` (TODO: address problems with name-changing and collision)
0.9.3.1 (2024/5/24): fix the bug with `unitvar == clustervar` (TODO: address problems with name-changing and collision).
0.9.3.2 (2024/7/17): fix group_time result when using `control_option = "notyet"` and make the base period in plots adapt to anticipation.
0.9.3.3 (2024/7/22): fix the anticipation out-of-bounds problem and add a more permanent solution for the group_time target problem.

## 0.9.2 (2023/12/20)

Expand Down
1 change: 0 additions & 1 deletion development/buildtest.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,4 @@ load_all()
run_test_dir()

#before release
build()
check()

0 comments on commit 36a8de2

Please sign in to comment.