diff --git a/.Rbuildignore b/.Rbuildignore index 9c86736..224a30a 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -5,3 +5,5 @@ ^pkg.Rproj$ figure$ cache$ +^.*\.Rproj$ +^\.Rproj\.user$ diff --git a/.gitignore b/.gitignore index f53558f..561017c 100644 --- a/.gitignore +++ b/.gitignore @@ -13,7 +13,7 @@ inst/doc # R stuff *.Rout *.Rhistory -*.RData +*.RData *.Rapp.history # Mac stuff @@ -27,4 +27,10 @@ inst/doc test.R -*-vignette.pdf \ No newline at end of file +*-vignette.pdf +.Rproj.user + +# scratch folder +/scratch +tobacco_replication/tobacco_replication_cache +tobacco_replication/tobacco_replication_files/figure-html diff --git a/DESCRIPTION b/DESCRIPTION index 096311b..e2b0cf1 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -23,7 +23,7 @@ Remotes: License: MIT + file LICENSE Encoding: UTF-8 LazyData: true -RoxygenNote: 7.2.3 +RoxygenNote: 7.3.3 Suggests: testthat, CausalImpact, diff --git a/NAMESPACE b/NAMESPACE index c32e4f6..4d52250 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,5 +1,13 @@ # Generated by roxygen2: do not edit by hand +S3method(dim,augsynth) +S3method(dim,multisynth) +S3method(n_time,augsynth) +S3method(n_time,multisynth) +S3method(n_treated,augsynth) +S3method(n_treated,multisynth) +S3method(n_unit,augsynth) +S3method(n_unit,multisynth) S3method(plot,augsynth) S3method(plot,augsynth_multiout) S3method(plot,multisynth) @@ -18,13 +26,26 @@ S3method(print,summary.multisynth) S3method(summary,augsynth) S3method(summary,augsynth_multiout) S3method(summary,multisynth) +export(RMSPE) +export(add_inference) export(augsynth) export(augsynth_multiout) +export(covariate_balance_table) +export(donor_table) +export(is.augsynth) +export(is_summary_augsynth) export(multisynth) +export(n_time) +export(n_treated) +export(n_unit) +export(permutation_plot) +export(placebo_distribution) export(rdirichlet_b) export(rmultinom_b) export(rwild_b) export(single_augsynth) +export(treated_table) +export(update_augsynth) import(dplyr) import(tidyr) importFrom(ggplot2,aes) diff --git a/R/augsynth.R b/R/augsynth.R index ab68883..c1e226a 100644 --- a/R/augsynth.R +++ b/R/augsynth.R @@ -1,10 +1,10 @@ -################################################################################ -## Main functions for single-period treatment augmented synthetic controls Method -################################################################################ + +#### Main functions for single-period treatment augmented synthetic controls Method #### + #' Fit Augmented SCM -#' +#' #' @param form outcome ~ treatment | auxillary covariates #' @param unit Name of unit column #' @param time Name of time column @@ -14,14 +14,13 @@ #' ridge=Ridge regression (allows for standard errors), #' none=No outcome model, #' en=Elastic Net, RF=Random Forest, GSYN=gSynth, -#' mcp=MCPanel, +#' mcp=MCPanel, #' cits=Comparitive Interuppted Time Series #' causalimpact=Bayesian structural time series with CausalImpact -#' @param scm Whether the SCM weighting function is used -#' @param fixedeff Whether to include a unit fixed effect, default F +#' @param scm Whether the SCM weighting function is used. If FALSE, then package will fit the outcome model, but not calculate new donor weights to match pre-treatment covariates. Instead, each donor unit will be equally weighted. If TRUE, weights on donor pool will be calculated. +#' @param fixedeff Whether to include a unit fixed effect, default F #' @param cov_agg Covariate aggregation functions, if NULL then use mean with NAs omitted -#' @param ... 
optional arguments for outcome model -#' +#' @param plot Whether or not to return a plot of the augsynth model #' @return augsynth object that contains: #' \itemize{ #' \item{"weights"}{Ridge ASCM weights} @@ -30,12 +29,15 @@ #' \item{"mhat"}{Outcome model estimate} #' \item{"data"}{Panel data as matrices} #' } +#' @param ... optional arguments for outcome model #' @export single_augsynth <- function(form, unit, time, t_int, data, - progfunc = "ridge", - scm=T, - fixedeff = FALSE, - cov_agg=NULL, ...) { + progfunc = "ridge", + scm=T, + fixedeff = FALSE, + cov_agg=NULL, + ...) { + call_name <- match.call() form <- Formula::Formula(form) @@ -48,30 +50,39 @@ single_augsynth <- function(form, unit, time, t_int, data, wide <- format_data(outcome, trt, unit, time, t_int, data) synth_data <- do.call(format_synth, wide) - + treated_units <- data %>% filter(!!trt == 1) %>% distinct(!!unit) %>% pull(!!unit) - control_units <- data %>% filter(!(!!unit %in% treated_units)) %>% - distinct(!!unit) %>% arrange(!!unit) %>% pull(!!unit) - ## add covariates + control_units <- data %>% filter(!(!!unit %in% treated_units)) %>% + distinct(!!unit) %>% arrange(!!unit) %>% pull(!!unit) + ## add covariates if(length(form)[2] == 2) { Z <- extract_covariates(form, unit, time, t_int, data, cov_agg) } else { Z <- NULL } - + # fit augmented SCM - augsynth <- fit_augsynth_internal(wide, synth_data, Z, progfunc, + augsynth <- fit_augsynth_internal(wide, synth_data, Z, progfunc, scm, fixedeff, ...) - + # add some extra data augsynth$data$time <- data %>% distinct(!!time) %>% - arrange(!!time) %>% pull(!!time) + arrange(!!time) %>% pull(!!time) augsynth$call <- call_name - augsynth$t_int <- t_int - + augsynth$t_int <- t_int + augsynth$weights <- matrix(augsynth$weights) rownames(augsynth$weights) <- control_units + # TODO update similar attribute for multi <- if want to use for plotting later + augsynth$trt_unit <- data %>% filter(!!as.name(trt) == 1) %>% + pull(quo_name(unit)) %>% unique() + augsynth$time_var <- quo_name(time) + augsynth$unit_var <- quo_name(unit) + augsynth$raw_data <- data + augsynth$form <- form + augsynth$cov_agg <- cov_agg + return(augsynth) } @@ -85,14 +96,19 @@ single_augsynth <- function(form, unit, time, t_int, data, #' @param fixedeff Whether to de-mean synth #' @param V V matrix for Synth, default NULL #' @param ... Extra args for outcome model -#' +#' #' @noRd -#' +#' fit_augsynth_internal <- function(wide, synth_data, Z, progfunc, scm, fixedeff, V = NULL, ...) { n <- nrow(wide$X) t0 <- ncol(wide$X) + + if ( progfunc == "ridge" && abs( n - t0 ) <= 1 ) { + warning( paste0( "Tuning ridge regression when number of units (", n, ") is almost equal to the number of time periods (", t0, ") is often unstable." ), call. 
= FALSE ) + } + ttot <- t0 + ncol(wide$y) if(fixedeff) { demeaned <- demean_data(wide, synth_data) @@ -108,33 +124,39 @@ fit_augsynth_internal <- function(wide, synth_data, Z, progfunc, progfunc = "none" } progfunc = tolower(progfunc) + ## fit augsynth if(progfunc == "ridge") { # Ridge ASCM augsynth <- do.call(fit_ridgeaug_formatted, list(wide_data = fit_wide, synth_data = fit_synth_data, - Z = Z, V = V, scm = scm, ...)) + Z = Z, V = V, + ridge = TRUE, scm = scm, ...)) } else if(progfunc == "none") { ## Just SCM + if ( !scm ) { + stop( "Cannot run with progfunc='none' AND no SCM weights" ) + } augsynth <- do.call(fit_ridgeaug_formatted, - c(list(wide_data = fit_wide, - synth_data = fit_synth_data, - Z = Z, ridge = F, scm = T, V = V, ...))) + c(list(wide_data = fit_wide, + synth_data = fit_synth_data, + Z = Z, V = V, + ridge = FALSE, scm = scm, ...))) } else { ## Other outcome models progfuncs = c("ridge", "none", "en", "rf", "gsyn", "mcp", "cits", "causalimpact", "seq2seq") if (progfunc %in% progfuncs) { - augsynth <- fit_augsyn(fit_wide, fit_synth_data, + augsynth <- fit_augsyn(fit_wide, fit_synth_data, progfunc, scm, ...) } else { stop("progfunc must be one of 'EN', 'RF', 'GSYN', 'MCP', 'CITS', 'CausalImpact', 'seq2seq', 'None'") } - + } - augsynth$mhat <- mhat + cbind(matrix(0, nrow = n, ncol = t0), + augsynth$mhat <- mhat + cbind(matrix(0, nrow = n, ncol = t0), augsynth$mhat) augsynth$data <- wide augsynth$data$Z <- Z @@ -168,13 +190,13 @@ predict.augsynth <- function(object, att = F, ...) { # att <- F # } augsynth <- object - + X <- augsynth$data$X y <- augsynth$data$y comb <- cbind(X, y) trt <- augsynth$data$trt mhat <- augsynth$mhat - + m1 <- colMeans(mhat[trt==1,,drop=F]) resid <- (comb[trt==0,,drop=F] - mhat[trt==0,drop=F]) @@ -191,147 +213,81 @@ predict.augsynth <- function(object, att = F, ...) { } -#' Print function for augsynth -#' @param x augsynth object -#' @param ... Optional arguments -#' @export -print.augsynth <- function(x, ...) { - augsynth <- x - - ## straight from lm - cat("\nCall:\n", paste(deparse(augsynth$call), sep="\n", collapse="\n"), "\n\n", sep="") - - ## print att estimates - tint <- ncol(augsynth$data$X) - ttotal <- tint + ncol(augsynth$data$y) - att_post <- predict(augsynth, att = T)[(tint + 1):ttotal] - - cat(paste("Average ATT Estimate: ", - format(round(mean(att_post),3), nsmall = 3), "\n\n", sep="")) -} - -#' Plot function for augsynth -#' @importFrom graphics plot -#' -#' @param x Augsynth object to be plotted -#' @param inf Boolean, whether to get confidence intervals around the point estimates -#' @param cv If True, plot cross validation MSE against hyper-parameter, otherwise plot effects -#' @param ... Optional arguments -#' @export -plot.augsynth <- function(x, inf = T, cv = F, ...) 
{ - # if ("se" %in% names(list(...))) { - # se <- list(...)$se - # } else { - # se <- T - # } - augsynth <- x - - if (cv == T) { - errors = data.frame(lambdas = augsynth$lambdas, - errors = augsynth$lambda_errors, - errors_se = augsynth$lambda_errors_se) - p <- ggplot2::ggplot(errors, ggplot2::aes(x = lambdas, y = errors)) + - ggplot2::geom_point(size = 2) + - ggplot2::geom_errorbar( - ggplot2::aes(ymin = errors, - ymax = errors + errors_se), - width=0.2, size = 0.5) - p <- p + ggplot2::labs(title = bquote("Cross Validation MSE over " ~ lambda), - x = expression(lambda), y = "Cross Validation MSE", - parse = TRUE) - p <- p + ggplot2::scale_x_log10() - - # find minimum and min + 1se lambda to plot - min_lambda <- choose_lambda(augsynth$lambdas, - augsynth$lambda_errors, - augsynth$lambda_errors_se, - F) - min_1se_lambda <- choose_lambda(augsynth$lambdas, - augsynth$lambda_errors, - augsynth$lambda_errors_se, - T) - min_lambda_index <- which(augsynth$lambdas == min_lambda) - min_1se_lambda_index <- which(augsynth$lambdas == min_1se_lambda) - - p <- p + ggplot2::geom_point( - ggplot2::aes(x = min_lambda, - y = augsynth$lambda_errors[min_lambda_index]), - color = "gold") - p + ggplot2::geom_point( - ggplot2::aes(x = min_1se_lambda, - y = augsynth$lambda_errors[min_1se_lambda_index]), - color = "gold") + - ggplot2::theme_bw() - } else { - plot(summary(augsynth, ...), inf = inf) - } -} - - -#' Summary function for augsynth +#' Function to add inference to augsynth object #' @param object augsynth object -#' @param inf Boolean, whether to get confidence intervals around the point estimates #' @param inf_type Type of inference algorithm. Options are #' \itemize{ #' \item{"conformal"}{Conformal inference (default)} #' \item{"jackknife+"}{Jackknife+ algorithm over time periods} #' \item{"jackknife"}{Jackknife over units} +#' \item{"permutation"}{Placebo permutation, raw ATT} +#' \item{"permutation_rstat"}{Placebo permutation, RMSPE adjusted ATT} #' } -#' @param linear_effect Boolean, whether to invert the conformal inference hypothesis test to get confidence intervals for a linear-in-time treatment effect: intercept + slope * time +#' @param linear_effect Boolean, whether to invert the conformal inference hypothesis test to get confidence intervals for a linear-in-time treatment effect: intercept + slope * time #' @param ... Optional arguments for inference, for more details for each `inf_type` see #' \itemize{ #' \item{"conformal"}{`conformal_inf`} #' \item{"jackknife+"}{`time_jackknife_plus`} #' \item{"jackknife"}{`jackknife_se_single`} +#' \item{"permutation"}{`permutation_inf`} #' } #' @export -summary.augsynth <- function(object, inf = T, inf_type = "conformal", - linear_effect = F, - ...) { + +add_inference <- function(object, inf_type = "conformal", linear_effect = F, ...) { augsynth <- object - summ <- list() t0 <- ncol(augsynth$data$X) t_final <- t0 + ncol(augsynth$data$y) - if(inf) { + if (tolower(inf_type) != "none") { + if(inf_type == "jackknife") { att_se <- jackknife_se_single(augsynth) } else if(inf_type == "jackknife+") { att_se <- time_jackknife_plus(augsynth, ...) } else if(inf_type == "conformal") { - att_se <- conformal_inf(augsynth, ...) - # get CIs for linear treatment effects - if(linear_effect) { - att_linear <- conformal_inf_linear(augsynth, ...) - } + att_se <- conformal_inf(augsynth, ...) + + if(linear_effect) { + att_linear <- conformal_inf_linear(augsynth, ...) 
+ } + + } else if (inf_type %in% c('permutation', 'permutation_rstat')) { + if (is.null(augsynth$results$permutations)) { + augsynth <- add_placebo_distribution(augsynth) + } + att_se <- permutation_inf(augsynth, inf_type) } else { stop(paste(inf_type, "is not a valid choice of 'inf_type'")) } att <- data.frame(Time = augsynth$data$time, Estimate = att_se$att[1:t_final]) + rownames(att) <- att$Time + if(inf_type == "jackknife") { - att$Std.Error <- att_se$se[1:t_final] - att_avg_se <- att_se$se[t_final + 1] + att$Std.Error <- att_se$se[1:t_final] + att_avg_se <- att_se$se[t_final + 1] + } else { + att_avg_se <- NA + } + if( length( att_se$att ) > t_final ) { + att_avg <- att_se$att[t_final + 1] } else { - att_avg_se <- NA + att_avg <- mean(att$Estimate[(t0 + 1):t_final]) } - att_avg <- att_se$att[t_final + 1] - if(inf_type %in% c("jackknife+", "nonpar_bs", "t_dist", "conformal")) { + if(inf_type %in% c("jackknife+", "nonpar_bs", "t_dist", "conformal", "permutation", "permutation_rstat")) { att$lower_bound <- att_se$lb[1:t_final] att$upper_bound <- att_se$ub[1:t_final] } - if(inf_type == "conformal") { - att$p_val <- att_se$p_val[1:t_final] + if(inf_type %in% c("conformal", "permutation", "permutation_rstat")) { + att$p_val <- att_se$p_val[1:t_final] } - } else { - t0 <- ncol(augsynth$data$X) - t_final <- t0 + ncol(augsynth$data$y) - att_est <- predict(augsynth, att = T) + # No inference, make table of estimates and NAs for SEs, etc. + att_est <- predict(augsynth, att = TRUE) att <- data.frame(Time = augsynth$data$time, Estimate = att_est) att$Std.Error <- NA @@ -339,62 +295,252 @@ summary.augsynth <- function(object, inf = T, inf_type = "conformal", att_avg_se <- NA } - summ$att <- att + augsynth$results$att <- att + augsynth$results$average_att <- data.frame(Value = "Average Post-Treatment Effect", + Estimate = att_avg, Std.Error = att_avg_se) - if(inf) { - if(inf_type %in% c("jackknife+")) { - summ$average_att <- data.frame(Value = "Average Post-Treatment Effect", - Estimate = att_avg, Std.Error = att_avg_se) - summ$average_att$lower_bound <- att_se$lb[t_final + 1] - summ$average_att$upper_bound <- att_se$ub[t_final + 1] - summ$alpha <- att_se$alpha - } - if(inf_type == "conformal") { - # summ$average_att$p_val <- att_se$p_val[t_final + 1] - # summ$average_att$lower_bound <- att_se$lb[t_final + 1] - # summ$average_att$upper_bound <- att_se$ub[t_final + 1] - # summ$alpha <- att_se$alpha - if(linear_effect) { - summ$average_att <- data.frame( - Value = c("Average Post-Treatment Effect", - "Treatment Effect Intercept", - "Treatment Effect Slope"), - Estimate = c(att_avg, att_linear$est_int, - att_linear$est_slope), - Std.Error = c(att_avg_se, NA, NA), - p_val = c(att_se$p_val[t_final + 1], NA, NA), - lower_bound = c(att_se$lb[t_final + 1], - att_linear$ci_int[1], - att_linear$ci_slope[1]), - upper_bound = c(att_se$ub[t_final + 1], - att_linear$ci_int[2], - att_linear$ci_slope[2]) - ) - } else { - summ$average_att <- data.frame( - Value = c("Average Post-Treatment Effect"), - Estimate = att_avg, - Std.Error = att_avg_se, - p_val = att_se$p_val[t_final + 1], - lower_bound = att_se$lb[t_final + 1], - upper_bound = att_se$ub[t_final + 1] - ) + if(inf_type %in% c("jackknife+", "conformal", "permutation", "permutation_rstat")) { + augsynth$results$average_att$lower_bound <- att_se$lb[t_final + 1] + augsynth$results$average_att$upper_bound <- att_se$ub[t_final + 1] + augsynth$results$alpha <- att_se$alpha + + if (inf_type == 'conformal') { + if(linear_effect) { + augsynth$results$average_att <- 
data.frame( + Value = c("Average Post-Treatment Effect", + "Treatment Effect Intercept", + "Treatment Effect Slope"), + Estimate = c(att_avg, att_linear$est_int, + att_linear$est_slope), + Std.Error = c(att_avg_se, NA, NA), + p_val = c(att_se$p_val[t_final + 1], NA, NA), + lower_bound = c(att_se$lb[t_final + 1], + att_linear$ci_int[1], + att_linear$ci_slope[1]), + upper_bound = c(att_se$ub[t_final + 1], + att_linear$ci_int[2], + att_linear$ci_slope[2]) + ) + } else { + augsynth$results$average_att <- data.frame( + Value = c("Average Post-Treatment Effect"), + Estimate = att_avg, + Std.Error = att_avg_se, + p_val = att_se$p_val[t_final + 1], + lower_bound = att_se$lb[t_final + 1], + upper_bound = att_se$ub[t_final + 1] + ) + + } } - summ$alpha <- att_se$alpha - } - } else { - summ$average_att <- data.frame(Value = "Average Post-Treatment Effect", - Estimate = att_avg, Std.Error = att_avg_se) } - summ$t_int <- augsynth$t_int - summ$call <- augsynth$call - summ$l2_imbalance <- augsynth$l2_imbalance - summ$scaled_l2_imbalance <- augsynth$scaled_l2_imbalance - if(!is.null(augsynth$covariate_l2_imbalance)) { - summ$covariate_l2_imbalance <- augsynth$covariate_l2_imbalance - summ$scaled_covariate_l2_imbalance <- augsynth$scaled_covariate_l2_imbalance + if(inf_type %in% c("conformal", "permutation", "permutation_rstat")) { + augsynth$results$average_att$p_val <- att_se$p_val[t_final + 1] } + + augsynth$results$inf_type <- inf_type + + return(augsynth) + +} + +#' Print function for augsynth +#' @param x augsynth object +#' @param ... Optional arguments +#' @export +print.augsynth <- function(x, ...) { + augsynth <- x + + ## straight from lm + cat("\nCall:\n", paste(deparse(augsynth$call), sep="\n", collapse="\n"), "\n", sep="") + cat( sprintf( " Fit to %d units and %d time points\n\n", n_unit(augsynth), n_time(augsynth) ) ) + ## print att estimates + tint <- ncol(augsynth$data$X) + ttotal <- tint + ncol(augsynth$data$y) + att_post <- predict(augsynth, att = T)[(tint + 1):ttotal] + + cat(paste("Average ATT Estimate: ", + format(round(mean(att_post),3), nsmall = 3), "\n\n", sep="")) +} + +#' Plot function for augsynth +#' @importFrom graphics plot +#' +#' +#' @param augsynth Augsynth or summary.augsynth object to be plotted +#' @param plot_type The stylized plot type to be returned. Options include +#' \itemize{ +#' \item{"estimate"}{The ATT and 95\% confidence interval} +#' \item{"estimate only"}{The ATT without a confidence interval} +#' \item{"outcomes"}{The level of the outcome variable for the treated and synthetic control units.} +#' \item{"outcomes raw average"}{The level of the outcome variable for the treated and synthetic control units, along with the raw average of the donor units.} +#' \item{"placebo"}{The ATTs resulting from placebo tests on the donor units.} } +#' @param cv If True, plot cross validation MSE against hyper-parameter, otherwise plot effects +#' @param inf_type Type of inference algorithm. Inherits inf_type from `object` or otherwise defaults to "conformal". Options are +#' \itemize{ +#' \item{"conformal"}{Conformal inference (default)} +#' \item{"jackknife+"}{Jackknife+ algorithm over time periods} +#' \item{"jackknife"}{Jackknife over units} +#' \item{"permutation"}{Placebo permutation, raw ATT} +#' \item{"permutation_rstat"}{Placebo permutation, RMSPE adjusted ATT} +#' \item{"None"}{Return ATT Estimate only} +#' } +#' @param ... 
Optional arguments for inference, for more details for each `inf_type` see +#' \itemize{ +#' \item{"conformal"}{`conformal_inf`} +#' \item{"jackknife+"}{`time_jackknife_plus`} +#' \item{"jackknife"}{`jackknife_se_single`} +#' \item{"permutation"}{`permutation_inf`} +#' } +#' @param ... Optional arguments +#' @export +plot.augsynth <- function(augsynth, + cv = FALSE, + plot_type = 'estimate', + inf_type = NULL, ...) { + + plot_augsynth_results( augsynth=augsynth, cv=cv, plot_type=plot_type, inf_type=inf_type, ... ) + +} + +#' Methods for accessing details of augsynth result object (of class augsynth) +#' +#' +#' @param x augsynth result object +#' +#' @rdname augsynth_class +#' +#' @return is.augsynth: TRUE if object is a augsynth object. +#' +#' @export +is.augsynth <- function(x) { + inherits(x, "augsynth") +} + + + +#' +#' @return dim: Dimension of data as pair of (# units, # time points). +#' +#' @rdname augsynth_class +#' @export +dim.augsynth <- function(x, ... ) { + n_unit = length( unique( x$raw_data[[ x$unit_var ]] ) ) + n_time = length( unique( x$raw_data[[ x$time_var ]] ) ) + return( c( n_unit, n_time ) ) +} + + +#' +#' @return Single number (of unique units). +#' +#' @rdname synth_class +#' @export +n_unit <- function(x, ... ) { + UseMethod( "n_unit" ) +} + +#' @title Number of time points in fit data +#' +#' @rdname synth_class +#' @return Single number (of unique time points). +#' @export +n_time <- function(x, ... ) { + UseMethod( "n_time" ) +} + + + +#' @title Number of treated units in fit data +#' +#' @rdname synth_class +#' @return Single number (of number of treated units). +#' @export +n_treated <- function(x, ... ) { + UseMethod( "n_treated" ) +} + + + +#' +#' @return Single number (of unique units). +#' +#' @rdname augsynth_class +#' +#' @export +#' +n_unit.augsynth <- function(x, ... ) { + dim.augsynth(x)[[1]] +} + +#' @title Number of time points in augsynth +#' +#' @rdname augsynth_class +#' +#' @return Single number (of unique time points). +#' @export +n_time.augsynth <- function(x, ... ) { + dim.augsynth(x)[[2]] +} + + +#' +#' @rdname augsynth_class +#' +#' @return Number of treated units (always 1 for augsynth) +#' @export +n_treated.augsynth <- function(x, ... ) { + return( 1 ) +} + + +#### Summary methods #### + + +#' Summary function for augsynth +#' +#' Summary summarizes an augsynth result by (usually) adding an +#' inferential result, if that has not been calculated already, and +#' calculating a few other summary statistics such as estimated bias. +#' This method does this via `add_inference()`, if inference is +#' needed. +#' +#' @param object augsynth object +#' +#' @param inf_type Type of inference algorithm. If left NULL, inherits +#' inf_type from `object` or otherwise defaults to "conformal." +#' Options are +#' \itemize{ +#' \item{"conformal"}{Conformal inference (default)} +#' \item{"jackknife+"}{Jackknife+ algorithm over time periods} +#' \item{"jackknife"}{Jackknife over units} +#' \item{"permutation"}{Placebo permutation, raw ATT} +#' \item{"permutation_rstat"}{Placebo permutation, RMSPE adjusted ATT} +#' } +#' @param ... Optional arguments for inference, for more details for +#' each `inf_type` see +#' \itemize{ +#' \item{"conformal"}{`conformal_inf`} +#' \item{"jackknife+"}{`time_jackknife_plus`} +#' \item{"jackknife"}{`jackknife_se_single`} +#' \item{"permutation", "permutation_rstat"}{`permutation_inf`} +#' } +#' @export +summary.augsynth <- function(object, inf_type = 'conformal', ...) 
{ + augsynth <- object + + t0 <- ncol(augsynth$data$X) + t_final <- t0 + ncol(augsynth$data$y) + + augsynth <- add_inference(augsynth, inf_type = inf_type) + + # Copy over all of OG object except for data + nms = names(augsynth) + nms = nms[!nms %in% c("data", "raw_data", "results")] + summ <- augsynth$results + summ <- c( summ, augsynth[nms] ) + ## get estimated bias if(tolower(augsynth$progfunc) == "ridge") { @@ -410,32 +556,62 @@ summary.augsynth <- function(object, inf = T, inf_type = "conformal", if(tolower(augsynth$progfunc) == "none" | (!augsynth$scm)) { summ$bias_est <- NA } else { - summ$bias_est <- m1 - t(mhat[trt==0,,drop=F]) %*% w + summ$bias_est <- m1 - t(mhat[trt==0,,drop=F]) %*% w } - - - summ$inf_type <- if(inf) inf_type else "None" + + summ$n_unit <- n_unit(augsynth) + summ$n_time <- n_time(augsynth) + summ$n_tx <- n_treated(augsynth)[1] + summ$time_tx <- t0 + summ$donor_table <- donor_table( augsynth ) + + summ$treated_table <- treated_table( augsynth ) + class(summ) <- "summary.augsynth" return(summ) } + + +#' Methods for accessing details of summary.augsynth object +#' +#' @param x summary.augsynth result object +#' +#' @rdname summary.augsynth_class +#' +#' @return is_summary_augsynth: TRUE if object is a augsynth object. +#' +#' @export +is_summary_augsynth <- function(x) { + inherits(x, "summary.augsynth") +} + + + #' Print function for summary function for augsynth +#' #' @param x summary object #' @param ... Optional arguments #' @export print.summary.augsynth <- function(x, ...) { summ <- x - + ## straight from lm - cat("\nCall:\n", paste(deparse(summ$call), sep="\n", collapse="\n"), "\n\n", sep="") + cat("\nCall:\n", paste(deparse(summ$call), sep="\n", collapse="\n"), "\n", sep="") t_final <- nrow(summ$att) + cat( sprintf( "\nFit to %d units and %d+%d = %d time points; %g treated at %s %g.\n", + summ$n_unit, summ$time_tx, t_final - summ$time_tx, t_final, + summ$n_tx, + summ$time_var, + summ$att$Time[[summ$time_tx+1]]) ) + cat( "\n" ) ## distinction between pre and post treatment att_est <- summ$att$Estimate t_total <- length(att_est) t_int <- summ$att %>% filter(Time <= summ$t_int) %>% nrow() - + att_pre <- att_est[1:(t_int-1)] att_post <- att_est[t_int:t_total] @@ -447,20 +623,22 @@ print.summary.augsynth <- function(x, ...) { att_post <- summ$average_att$Estimate[1] se_est <- summ$att$Std.Error if(summ$inf_type == "jackknife") { - se_avg <- summ$average_att$Std.Error + se_avg <- summ$average_att$Std.Error - out_msg <- paste("Average ATT Estimate (Jackknife Std. Error): ", - format(round(att_post,3), nsmall=3), - " (", - format(round(se_avg,3)), ")\n") - inf_type <- "Jackknife over units" + out_msg <- paste("Average ATT Estimate (Jackknife Std. Error): ", + format(round(att_post,3), nsmall=3), + " (", + format(round(se_avg,3)), ")\n") + inf_type <- "Jackknife over units" } else if(summ$inf_type == "conformal") { - p_val <- summ$average_att$p_val[1] - out_msg <- paste("Average ATT Estimate (p Value for Joint Null): ", - format(att_post, digits = 3), - " (", - format(p_val, digits = 2), ")\n") - inf_type <- "Conformal inference" + + p_val <- summ$average_att$p_val + out_msg <- paste("Average ATT Estimate (p Value for Joint Null): ", + format(att_post, digits = 3), + " (", + format(p_val, digits = 2), ")\n") + inf_type <- "Conformal inference" + if("Treatment Effect Slope" %in% summ$average_att$Value) { lowers <- summ$average_att$lower_bound[2:3] uppers <- summ$average_att$upper_bound[2:3] @@ -471,112 +649,444 @@ print.summary.augsynth <- function(x, ...) 
{ format(uppers[2], digits = 3), "]\n") out_msg <- paste0(out_msg, out_msg_line2) } + } else if(summ$inf_type == "jackknife+") { - out_msg <- paste("Average ATT Estimate: ", - format(round(att_post,3), nsmall=3), "\n") - inf_type <- "Jackknife+ over time periods" + out_msg <- paste("Average ATT Estimate: ", + format(round(att_post,3), nsmall=3), "\n") + inf_type <- "Jackknife+ over time periods" + } else if (summ$inf_type %in% c('permutation', "permutation_rstat")) { + out_msg <- paste("Average ATT Estimate: ", + format(round(att_post,3), nsmall=3), "\n") + inf_type <- ifelse(summ$inf_type == 'permutation', + "Permutation inference", + "Permutation inference (RMSPE-adjusted)") + out_msg <-paste0( out_msg, "\n", + ( sprintf( "Donor RMSPE range from %.2f to %.2f\n", + min( summ$donor_table$RMSPE ), max( summ$donor_table$RMSPE ) ) ) ) + } else { - out_msg <- paste("Average ATT Estimate: ", - format(round(att_post,3), nsmall=3), "\n") - inf_type <- "None" + out_msg <- paste("Average ATT Estimate: ", + format(round(att_post,3), nsmall=3), "\n") + inf_type <- "None" } - out_msg <- paste(out_msg, - "L2 Imbalance: ", - format(round(summ$l2_imbalance,3), nsmall=3), "\n", - "Percent improvement from uniform weights: ", - format(round(1 - summ$scaled_l2_imbalance,3)*100), "%\n\n", - sep="") - if(!is.null(summ$covariate_l2_imbalance)) { - out_msg <- paste(out_msg, - "Covariate L2 Imbalance: ", - format(round(summ$covariate_l2_imbalance,3), - nsmall=3), - "\n", + "L2 Imbalance: ", + format(round(summ$l2_imbalance,3), nsmall=3), "\n", "Percent improvement from uniform weights: ", - format(round(1 - summ$scaled_covariate_l2_imbalance,3)*100), - "%\n\n", + format(round(1 - summ$scaled_l2_imbalance,3)*100), "%\n\n", + sep="") + + if(!is.null(summ$covariate_l2_imbalance)) { + + out_msg <- paste(out_msg, + "Covariate L2 Imbalance: ", + format(round(summ$covariate_l2_imbalance,3), + nsmall=3), + "\n", + "Percent improvement from uniform weights: ", + format(round(1 - summ$scaled_covariate_l2_imbalance,3)*100), + "%\n\n", + sep="") + + } + out_msg <- paste(out_msg, + "Avg Estimated Bias: ", + format(round(mean(summ$bias_est), 3),nsmall=3), "\n\n", + "Inference type: ", + inf_type, + "\n\n", sep="") - } - out_msg <- paste(out_msg, - "Avg Estimated Bias: ", - format(round(mean(summ$bias_est), 3),nsmall=3), "\n\n", - "Inference type: ", - inf_type, - "\n\n", - sep="") - cat(out_msg) + rng = range( summ$donor_table$weight[ summ$donor_table$weight > 1/(1000*summ$n_unit) ] ) + cat( sprintf( "%d donor units used with weights of %.3f to %.3f\n", + sum( abs(summ$donor_table$weight) > 1/(1000*summ$n_unit) ), rng[[1]], rng[[2]] ) ) + cat(out_msg) + if(summ$inf_type == "jackknife") { - out_att <- summ$att[t_int:t_final,] %>% - select(Time, Estimate, Std.Error) + out_att <- summ$att[t_int:t_final,] %>% + select(Time, Estimate, Std.Error) } else if(summ$inf_type == "conformal") { - out_att <- summ$att[t_int:t_final,] %>% - select(Time, Estimate, lower_bound, upper_bound, p_val) - names(out_att) <- c("Time", "Estimate", - paste0((1 - summ$alpha) * 100, "% CI Lower Bound"), - paste0((1 - summ$alpha) * 100, "% CI Upper Bound"), - paste0("p Value")) + out_att <- summ$att[t_int:t_final,] %>% + select(Time, Estimate, lower_bound, upper_bound, p_val) + names(out_att) <- c("Time", "Estimate", + paste0((1 - summ$alpha) * 100, "% CI Lower Bound"), + paste0((1 - summ$alpha) * 100, "% CI Upper Bound"), + paste0("p Value")) } else if(summ$inf_type == "jackknife+") { - out_att <- summ$att[t_int:t_final,] %>% - select(Time, Estimate, 
lower_bound, upper_bound) - names(out_att) <- c("Time", "Estimate", - paste0((1 - summ$alpha) * 100, "% CI Lower Bound"), - paste0((1 - summ$alpha) * 100, "% CI Upper Bound")) + out_att <- summ$att[t_int:t_final,] %>% + select(Time, Estimate, lower_bound, upper_bound) + names(out_att) <- c("Time", "Estimate", + paste0((1 - summ$alpha) * 100, "% CI Lower Bound"), + paste0((1 - summ$alpha) * 100, "% CI Upper Bound")) + } else if (summ$inf_type %in% c('permutation', "permutation_rstat")) { + out_att <- summ$att[t_int:t_final, ] %>% + select(Time, Estimate, lower_bound, upper_bound, p_val) + names(out_att) <- c("Time", "Estimate", + paste0((1 - summ$alpha) * 100, "% CI Lower Bound"), + paste0((1 - summ$alpha) * 100, "% CI Upper Bound"), + paste0('p Value')) } else { - out_att <- summ$att[t_int:t_final,] %>% - select(Time, Estimate) + out_att <- summ$att[t_int:t_final,] %>% + select(Time, Estimate) } out_att %>% - mutate_at(vars(-Time), ~ round(., 3)) %>% - print(row.names = F) + mutate_at(vars(-Time), ~ round(., 3)) %>% + print(row.names = F) +} + + + +#' Plot results from summarized augsynth +#' +#' Make a variety of plots, depending on plot_type. Default is to +#' plot impacts with associated uncertainty (if present). Other +#' options are estimates with no uncertainty ("estimate only"), the +#' level of outcome ("outcomes"), level of outcomes with the raw trend +#' as well ("outcomes raw average"), and the classic spagetti plot +#' ("placebo"). +#' +#' @param x summary.augsynth object +#' @inheritParams plot_augsynth_results +#' +#' @export +plot.summary.augsynth <- function(x, + plot_type = 'estimate', + ...) { + + summ <- x + + plot_augsynth_results( summ, plot_type = plot_type, ... ) - } + + +#### Plotting helper functions #### + + + + +#' Figure out what kind of quick-summary is needed given the desired +#' plot type. +#' +#' @noRd +get_right_summary <- function( augsynth, plot_type, inf_type ) { + + prior_inf = NULL + if ( !is.null(augsynth$results) ) { + prior_inf = augsynth$results$inf_type + } else { + prior_inf = "none" + } + + if ( plot_type=="estimates only" ) { + inf_type = prior_inf + } else if (plot_type == 'placebo') { + if ( prior_inf %in% c('permutation', 'permutation_rstat') ) { + inf_type = prior_inf + } else if ( is.null( inf_type ) && prior_inf == "none" ) { + # if the user specifies the "placebo" plot type without + # accompanying inference, default to placebo + inf_type = "permutation" + } else if (!inf_type %in% c('permutation', 'permutation_rstat')) { + message('Placebo plots are only available for permutation-based inference. The plot shows results from "inf_type = "permutation""') + inf_type = "permutation" + } + } else if (plot_type %in% c( "outcomes", "outcomes raw average" )) { + if ( is.null( inf_type ) ) { + inf_type = "none" + } + } else if ( is.null( prior_inf ) ) { + inf_type = "conformal" + } + + if ( is.null( inf_type ) || inf_type == "none" ) { + inf_type = prior_inf + } + + inf_type +} + + + + +#' Plot function for augsynth or summary.augsynth objects +#' +#' @importFrom graphics plot +#' +#' @param augsynth augsynth or summary.augsynth object to be plotted +#' @param plot_type The stylized plot type to be returned. 
Options include +#' \itemize{ +#' \item{"cv"}{Cross-validation diagnostic plot} +#' \item{"estimate"}{The ATT and 95\% confidence interval} +#' \item{"estimate only"}{The ATT without a confidence interval} +#' \item{"outcomes"}{The level of the outcome variable for the treated and synthetic control units.} +#' \item{"outcomes raw average"}{The level of the outcome variable for the treated and synthetic control units, along with the raw average of the donor units.} +#' \item{"placebo"}{The ATTs resulting from placebo tests on the donor units.} } +#' @param cv If True, plot cross validation MSE against hyper-parameter, otherwise plot effects +#' @param inf_type Type of inference algorithm. Inherits inf_type from `object` or otherwise defaults to "conformal". Options are +#' \itemize{ +#' \item{"conformal"}{Conformal inference (default)} +#' \item{"jackknife+"}{Jackknife+ algorithm over time periods} +#' \item{"jackknife"}{Jackknife over units} +#' \item{"permutation"}{Placebo permutation, raw ATT} +#' \item{"permutation_rstat"}{Placebo permutation, RMSPE adjusted ATT} +#' \item{"None"}{Return ATT Estimate only} +#' } +#' @param ... Optional arguments for inference, for more details for each `inf_type` see +#' \itemize{ +#' \item{"conformal"}{`conformal_inf`} +#' \item{"jackknife+"}{`time_jackknife_plus`} +#' \item{"jackknife"}{`jackknife_se_single`} +#' \item{"permutation"}{`permutation_inf`} +#' } +#' @param ... Optional arguments +plot_augsynth_results <- function( augsynth, + plot_type = 'estimate', + inf_type = NULL, ...) { + + if ( !is.null( inf_type ) ) { + inf_type = tolower(inf_type) + stopifnot( inf_type %in% c('conformal', 'jackknife', 'jackknife+', 'permutation', 'permutation_rstat', 'none')) + } + + # Summarize object if needed. + if ( is.augsynth(augsynth) ) { + it <- get_right_summary(augsynth, plot_type, inf_type) + message( + "Plotting augsynth objects may be slow. For faster results, first create a summary object ", + "and plot that object directly (e.g., s <- summary(augsynth_obj); plot(s))." 
+ ) + augsynth = summary(augsynth, inf_type=it) + } + + if (plot_type == "cv") { + errors = data.frame(lambdas = augsynth$lambdas, + errors = augsynth$lambda_errors, + errors_se = augsynth$lambda_errors_se) + p <- ggplot2::ggplot(errors, ggplot2::aes(x = lambdas, y = errors)) + + ggplot2::geom_point(size = 2) + + ggplot2::geom_errorbar( + ggplot2::aes(ymin = errors, + ymax = errors + errors_se), + width = 0.2, linewidth = 0.5) + p <- p + ggplot2::labs(title = bquote("Cross Validation MSE over " ~ lambda), + x = expression(lambda), y = "Cross Validation MSE", + parse = TRUE) + p <- p + ggplot2::scale_x_log10() + + # find minimum and min + 1se lambda to plot + min_lambda <- choose_lambda(augsynth$lambdas, + augsynth$lambda_errors, + augsynth$lambda_errors_se, + F) + min_1se_lambda <- choose_lambda(augsynth$lambdas, + augsynth$lambda_errors, + augsynth$lambda_errors_se, + T) + min_lambda_index <- which(augsynth$lambdas == min_lambda) + min_1se_lambda_index <- which(augsynth$lambdas == min_1se_lambda) + + p <- p + ggplot2::geom_point( + ggplot2::aes(x = min_lambda, + y = augsynth$lambda_errors[min_lambda_index]), + color = "gold") + p + ggplot2::geom_point( + ggplot2::aes(x = min_1se_lambda, + y = augsynth$lambda_errors[min_1se_lambda_index]), + color = "gold") + + ggplot2::theme_bw() + return(p) + } else if (plot_type == 'estimate only') { + p <- augsynth_plot_from_results(augsynth, ci = FALSE) + } else if (plot_type == 'estimate') { + p <- augsynth_plot_from_results(augsynth, ci = TRUE) + } else if (grepl('placebo', plot_type)) { + p <- permutation_plot(augsynth, inf_type = augsynth$inf_type) + } else if (plot_type == 'outcomes') { + p <- augsynth_outcomes_plot(augsynth, measure = 'synth') + } else if (plot_type == 'outcomes raw average') { + p <- augsynth_outcomes_plot(augsynth, measure = c('synth', 'average')) + } + return(p) +} + + + + #' Plot function for summary function for augsynth -#' @param x Summary object -#' @param inf Boolean, whether to plot confidence intervals +#' +#' @param augsynth Summary object #' @param ... Optional arguments -#' @export -plot.summary.augsynth <- function(x, inf = T, ...) { - summ <- x - # if ("inf" %in% names(list(...))) { - # inf <- list(...)$inf - # } else { - # inf <- T - # } - - p <- summ$att %>% - ggplot2::ggplot(ggplot2::aes(x=Time, y=Estimate)) - if(inf) { - if(all(is.na(summ$att$lower_bound))) { - p <- p + ggplot2::geom_ribbon(ggplot2::aes(ymin=Estimate-2*Std.Error, - ymax=Estimate+2*Std.Error), - alpha=0.2) +#' +#' @noRd +augsynth_plot_from_results <- function(augsynth, + ci = TRUE, + ...) 
{ + + pdat = NA + if ( is.augsynth(augsynth) ) { + pdat <- augsynth$results$att + } else { + # Summary object + pdat <- augsynth$att + } + + p <- pdat %>% + ggplot2::ggplot(ggplot2::aes(x = Time, y = Estimate)) + + if (ci) { + if(all(is.na(pdat$lower_bound))) { + p <- p + ggplot2::geom_ribbon(ggplot2::aes(ymin = Estimate - 2 * Std.Error, + ymax = Estimate + 2 * Std.Error), + alpha = 0.2) } else { - p <- p + ggplot2::geom_ribbon(ggplot2::aes(ymin=lower_bound, - ymax=upper_bound), - alpha=0.2) + p <- p + ggplot2::geom_ribbon(ggplot2::aes(ymin = lower_bound, + ymax = upper_bound), + alpha = 0.2) } } - p + ggplot2::geom_line() + - ggplot2::geom_vline(xintercept=summ$t_int, lty=2) + - ggplot2::geom_hline(yintercept=0, lty=2) + + p <- p + ggplot2::geom_line() + + ggplot2::geom_vline(xintercept = augsynth$t_int, lty = 2) + + ggplot2::geom_hline(yintercept = 0, lty = 2) + + ggplot2::labs(x = augsynth$time_var) + ggplot2::theme_bw() + return(p) + +} + + + +#' Plot the original level of the outcome variable for the treated +#' unit and its synthetic counterfactual +#' +#' @param augsynth Augsynth object or augsynth summary object to be plotted +#' @param measure Whether to plot the synthetic counterfactual or the +#' raw average of donor units. Can list both if desired. +#' +#' @noRd +augsynth_outcomes_plot <- function(augsynth, measure = c("synth", "average")) { + + if (!is_summary_augsynth(augsynth)) { + augsynth <- summary(augsynth) + } + + series = augsynth$treated_table + + all_y = c(series$Yobs, series$Yhat, series$raw_average) + max_y <- max(all_y) + min_y <- min(all_y) + + pt = which( series$tx == 1 )[[1]] + cut_time = (series$time[pt] + series$time[pt + 1]) / 2 + + p <- ggplot2::ggplot( series ) + + ggplot2::geom_line(aes(x = time, y = Yobs, linetype = as.character(augsynth$trt_unit))) + + if ('synth' %in% measure) { + p <- p + + ggplot2::geom_line(aes(x = time, y = Yhat, linetype = 'Synthetic counterfactual')) + } + + if ('average' %in% measure) { + p <- p + + ggplot2::geom_line(aes(x = time, y = raw_average, linetype = 'Donor raw average')) + } + + p <- p + + ggplot2::labs(linetype = NULL, + x = augsynth$time_var, + y = 'Outcome') + + ggplot2::ylim(min_y, max_y) + + ggplot2::theme_bw() + + ggplot2::geom_vline(xintercept = cut_time, linetype = 'dashed') + + ggplot2::theme(legend.position = 'bottom', + legend.key = ggplot2::element_rect(fill = scales::alpha("white", 0.5)), + legend.background = ggplot2::element_rect(fill = scales::alpha("white", 0)) + ) + + return(p) } + +#' Generate permutation plots +#' +#' @param results Results from calling augsynth() +#' +#' @param inf_type Inference type (takes a value of 'permutation' or +#' 'permutation_rstat') Type of inference algorithm. Inherits +#' inf_type from `object` or otherwise defaults to "conformal". 
+#'   Options are
+#' \itemize{
+#'          \item{"permutation"}{Placebo permutation, raw ATT}
+#'          \item{"permutation_rstat"}{Placebo permutation, RMSPE adjusted ATT}
+#' }
+#' @export
+permutation_plot <- function(augsynth, inf_type = 'permutation') {
+
+    if (!inf_type %in% c('permutation', 'permutation_rstat')) {
+        stop("Permutation plots are only available for `permutation` and `permutation_rstat` inference types")
+    }
+
+    if(inf_type == 'permutation') {
+        measure = "ATT"
+        y_lab = "Estimate (ATT)"
+    } else {
+        measure = 'rstat'
+        y_lab = "Estimate (RMSPE-adjusted ATT)"
+    }
+
+    if (is_summary_augsynth(augsynth)) {
+        placebo_dist <- augsynth$permutations$placebo_dist
+    } else if (is.null(augsynth$results$permutations)) {
+        augsynth <- add_placebo_distribution(augsynth)
+        placebo_dist <- augsynth$results$permutations$placebo_dist
+    } else {
+        placebo_dist <- augsynth$results$permutations$placebo_dist
+    }
+
+    plot_df <- placebo_dist %>%
+        mutate(trt_status = factor(trt, levels = c(0, 1), labels = c('Control', 'Treatment')))
+
+    t0 <- ncol(augsynth$data$X)
+    treat_year = augsynth$data$time[t0 + 1]
+
+    out_plot <- ggplot2::ggplot(plot_df,
+                                aes(x = !!as.name(augsynth$time_var), y = !!as.name(measure),
+                                    color = trt_status, linetype = !!as.name(augsynth$unit_var))) +
+        ggplot2::geom_line() +
+        ggplot2::geom_vline(lty = 2, xintercept = treat_year) +
+        ggplot2::geom_hline(lty = 2, yintercept = 0) +
+        ggplot2::scale_color_manual(values = c('Control' = 'gray', 'Treatment' = 'black')) +
+        ggplot2::scale_linetype_manual(values = rep('solid', length(unique(plot_df %>% pull(!!as.name(augsynth$unit_var)))))) +
+        ggplot2::labs(color = NULL, y = y_lab) +
+        ggplot2::guides(linetype = 'none') +
+        ggplot2::theme_bw() +
+        ggplot2::theme(legend.position = 'bottom')
+
+    return(out_plot)
+}
+
+
+
+
+
+#### Package documentation ####
+
+
 #' augsynth
-#' 
+#'
 #' @description A package implementing the Augmented Synthetic Controls Method
-#' @docType package
 #' @name augsynth-package
 #' @importFrom magrittr "%>%"
 #' @importFrom purrr reduce
@@ -584,9 +1094,10 @@ plot.summary.augsynth <- function(x, inf = T, ...) {
 #' @import tidyr
 #' @importFrom stats terms
 #' @importFrom stats formula
-#' @importFrom stats update 
-#' @importFrom stats delete.response 
-#' @importFrom stats model.matrix 
-#' @importFrom stats model.frame 
+#' @importFrom stats update
+#' @importFrom stats delete.response
+#' @importFrom stats model.matrix
+#' @importFrom stats model.frame
 #' @importFrom stats na.omit
-NULL
+"_PACKAGE"
+
diff --git a/R/augsynth_pre.R b/R/augsynth_pre.R
index a9dfe14..8593d2a 100644
--- a/R/augsynth_pre.R
+++ b/R/augsynth_pre.R
@@ -1,21 +1,29 @@
+
 ################################################################################
 ## Main function for the augmented synthetic controls Method
 ################################################################################
 
 #' Fit Augmented SCM
+#'
+#' The general `augsynth()` method dispatches to `single_augsynth`,
+#' `augsynth_multiout`, or `multisynth` based on the number of
+#' outcomes and treatment times. See documentation for these methods
+#' for further detail.
+#' #' @param form outcome ~ treatment | auxillary covariates #' @param unit Name of unit column #' @param time Name of time column #' @param data Panel data as dataframe -#' @param t_int Time of intervention (used for single-period treatment only) +#' @param t_int Time of intervention (used for single-period treatment +#' only) #' @param ... Optional arguments #' \itemize{ #' \item Single period augsynth with/without multiple outcomes #' \itemize{ #' \item{"progfunc"}{What function to use to impute control outcomes: Ridge=Ridge regression (allows for standard errors), None=No outcome model, EN=Elastic Net, RF=Random Forest, GSYN=gSynth, MCP=MCPanel, CITS=CITS, CausalImpact=Bayesian structural time series with CausalImpact, seq2seq=Sequence to sequence learning with feedforward nets} #' \item{"scm"}{Whether the SCM weighting function is used} -#' \item{"fixedeff"}{Whether to include a unit fixed effect, default F } +#' \item{"fixedeff"}{Whether to include a unit fixed effect, default is FALSE } #' \item{"cov_agg"}{Covariate aggregation functions, if NULL then use mean with NAs omitted} #' } #' \item Multi period (staggered) augsynth @@ -29,14 +37,16 @@ #' \item{"n_factors"}{Number of factors for interactive fixed effects, default does CV} #' } #' } -#' -#' @return augsynth object that contains: +#' +#' @return augsynth or multisynth object (depending on dispatch) that contains (among other things): #' \itemize{ #' \item{"weights"}{weights} #' \item{"data"}{Panel data as matrices} #' } +#' +#' @seealso `single_augsynth`, `augsynth_multiout`, `multisynth` #' @export -#' +#' augsynth <- function(form, unit, time, data, t_int=NULL, ...) { call_name <- match.call() @@ -44,7 +54,7 @@ augsynth <- function(form, unit, time, data, t_int=NULL, ...) { form <- Formula::Formula(form) unit_quosure <- enquo(unit) time_quosure <- enquo(time) - + ## format data outcome <- terms(formula(form, rhs=1))[[2]] @@ -69,7 +79,7 @@ augsynth <- function(form, unit, time, data, t_int=NULL, ...) { if("progfunc" %in% names(list(...))) { warning("`progfunc` is not an argument for multisynth, so it is ignored") } - return(multisynth(form, !!enquo(unit), !!enquo(time), data, ...)) + return(multisynth(form, !!enquo(unit), !!enquo(time), data, ...)) } else { if (is.null(t_int)) { t_int <- trt_time %>% filter(is.finite(trt_time)) %>% diff --git a/R/calc_covariate_balance.R b/R/calc_covariate_balance.R new file mode 100644 index 0000000..ecc38d3 --- /dev/null +++ b/R/calc_covariate_balance.R @@ -0,0 +1,51 @@ + + +#' Make covariate balance table. +#' +#' Make a table comparing means of covariates in the treated group, +#' the raw control group, and the new weighted control group (the +#' synthetic control) +#' +#' @param ascm An augsynth result object from single_augsynth +#' @param pre_period List of names of the pre-period timepoints to +#' calculate balance for. NULL means none. 
+#'
+#' @export
+covariate_balance_table = function( ascm, pre_period = NULL ) {
+
+    stopifnot( is.augsynth( ascm ) )
+
+    trt = ascm$data$trt
+    weight = rep( 0, length( trt ) )
+    weight[ trt == 0 ] = ascm$synw
+    stopifnot( abs( sum( weight ) - 1 ) < 0.000001 )
+
+    Z = ascm$data$Z
+    if ( !is.null( pre_period ) ) {
+        xx <- ascm$data$X[ , colnames(ascm$data$X) %in% pre_period, drop = FALSE]
+        if ( ncol( xx ) > 0 ) {
+            Z = cbind( Z, xx )
+        }
+    }
+
+    Co_means = t( Z ) %*% weight
+
+    # Means of the outcome at lagged time points
+    #Co_means_2 = ascm$data$synth_data$Z0 %*% ascm$synw
+
+    # Uniform weighting
+    n_donor = length(ascm$synw)
+    unit_weight = rep( 1 / n_donor, nrow(Z) )
+    unit_weight[ trt == 1 ] = 0
+    raw_means = t( Z ) %*% unit_weight
+
+    Tx_means = Z[ trt == 1, ]
+
+    means = tibble( variable = names( Tx_means ),
+                    Tx = as.numeric( Tx_means ),
+                    Co = as.numeric( Co_means ),
+                    Raw = as.numeric( raw_means ) )
+
+
+    means
+}
diff --git a/R/donor_control.R b/R/donor_control.R
new file mode 100644
index 0000000..e986c97
--- /dev/null
+++ b/R/donor_control.R
@@ -0,0 +1,119 @@
+
+# This file contains methods to monitor and modify the set of donor units
+# in the augmented synthetic control method
+
+
+#' Return a summary data frame of donor units used in the model with
+#' their synthetic weights.
+#'
+#' If permutation inference has been conducted, the table will also
+#' include RMSPEs. This can be forced with the include_RMSPE flag.
+#'
+#' If the augsynth object does not have permutation-based inference
+#' results, the function will run that form of inference in order to
+#' calculate the RMSPEs for each donor unit in turn.
+#'
+#' @param augsynth Augsynth object to summarize
+#' @param include_RMSPE Include RMSPEs in the table even if
+#'   permutation inference has not yet been conducted.
+#' @param zap_weights All weights smaller than this value will be set
+#'   to zero. Set to NULL to keep all weights.
+#' @export
+donor_table <- function(augsynth, include_RMSPE = TRUE, zap_weights = 0.0000001 ) {
+
+    if ( is_summary_augsynth( augsynth ) ) {
+        if ( !is.null( augsynth$donor_table ) ) {
+            return( augsynth$donor_table )
+        } else {
+            stop( "Call donor_table on original result from augsynth() or a permutation-inference summary" )
+        }
+    }
+
+    stopifnot( is.augsynth(augsynth) )
+
+    trt_index <- which(augsynth$data$trt == 1)
+    unit_var <- augsynth$unit_var
+
+    tbl = data.frame(
+        unit = rownames( augsynth$weights ),
+        weight = as.numeric(augsynth$weights) )
+    names(tbl)[[1]] = unit_var
+
+    # If RMSPEs already exist, or the flag says to calculate them, then calculate them
+    if ( include_RMSPE || (!is.null(augsynth$results) && augsynth$results$inf_type %in% c("permutation", "permutation_rstat")) ) {
+
+        if (is.null(augsynth$results) || (!(augsynth$results$inf_type %in% c("permutation", "permutation_rstat")) ) ) {
+            augsynth <- add_inference(augsynth, inf_type = 'permutation')
+        }
+        RMSPEs <- augsynth$results$permutations$placebo_dist %>%
+            select(!!unit_var, RMSPE) %>%
+            distinct()
+        tbl <- left_join(tbl, RMSPEs, by = unit_var)
+    }
+
+    if ( !is.null(zap_weights) ) {
+        tbl <- tbl %>% mutate(weight = ifelse(abs(weight) < zap_weights, 0, weight))
+    }
+
+    return(tbl)
+}
+
+
+#' Return a new augsynth object with specified donor units removed
+#'
+#' @param augsynth Augsynth object to update
+#'
+#' @param drop Drop donor units, based on pre-treatment RMSPE or unit
+#'   name(s). Default of 20 means drop units with an RMSPE 20x higher
+#'   than the treated unit.
The `drop` parameter can also be a character vector of unit +#' IDs to drop. +#' +#' @export +update_augsynth <- function(augsynth, drop = 20){ + + if (is.null(augsynth$results)){ + inf_type = 'none' + } else { + inf_type <- augsynth$results$inf_type + } + + # run placebo tests if necessary + if (!inf_type %in% c('permutation', 'permutation_rstat')) { + augsynth <- add_inference(augsynth, inf_type = 'permutation') + } + + unit_var <- augsynth$unit_var + # pre-treatment RMSPE among donors + donor_RMSPE <- augsynth$results$permutations$placebo_dist %>% + filter(!!as.name(augsynth$time_var) < augsynth$t_int) %>% + group_by(!!as.name(augsynth$unit_var)) %>% + summarise(RMSPE = sqrt(mean(ATT ^ 2)), .groups = "drop") + # pre-treatment RMSPE for treated unit + trt_RMSPE <- add_inference(augsynth, inf_type = 'permutation')$results$permutations$placebo_dist %>% + filter(!!as.name(augsynth$time_var) < augsynth$t_int) %>% + filter(!!as.name(unit_var) == augsynth$trt_unit) %>% + pull(RMSPE) %>% unique() + + if (is.numeric(drop)) { + keep_units <- donor_RMSPE %>% filter(RMSPE / trt_RMSPE <= drop) %>% pull(!!unit_var) + } else if (is.character(drop)) { + keep_units <- donor_RMSPE %>% filter((!!as.name(unit_var) %in% drop) == FALSE) %>% pull(!!unit_var) %>% unique() + } + keep_units <- c(keep_units, augsynth$trt_unit) + + form <- as.formula(paste(as.character(augsynth$form)[2], as.character(augsynth$form)[1], as.character(augsynth$form)[3])) + new_data <- as_tibble(augsynth$raw_data, .name_repair = 'unique') %>% + filter(!!as.name(unit_var) %in% keep_units) + + new_augsynth <- augsynth(form = form, + unit = !!as.name(augsynth$unit_var), + time = !!as.name(augsynth$time_var), + data = new_data, + progfunc = augsynth$progfunc, + scm = augsynth$scm, + fixedeff = augsynth$fixedeff, + cov_agg = augsynth$cov_agg + ) + + return(new_augsynth) +} diff --git a/R/fit_synth.R b/R/fit_synth.R index 69cddcd..17c006c 100644 --- a/R/fit_synth.R +++ b/R/fit_synth.R @@ -3,12 +3,14 @@ ####################################################### #' Make a V matrix from a vector (or null) +#' +#' @noRd make_V_matrix <- function(t0, V) { if(is.null(V)) { V <- diag(rep(1, t0)) } else if(is.vector(V)) { if(length(V) != t0) { - stop(paste("`V` must be a vector with", t0, "elements or a", t0, + stop(paste("`V` must be a vector with", t0, "elements or a", t0, "x", t0, "matrix")) } V <- diag(V) @@ -18,7 +20,7 @@ make_V_matrix <- function(t0, V) { V <- diag(c(V)) } else if(nrow(V) == t0) { } else { - stop(paste("`V` must be a vector with", t0, "elements or a", t0, + stop(paste("`V` must be a vector with", t0, "elements or a", t0, "x", t0, "matrix")) } @@ -61,7 +63,7 @@ fit_synth_formatted <- function(synth_data, V = NULL) { #' @param V Scaling matrix #' @noRd synth_qp <- function(X1, X0, V) { - + Pmat <- X0 %*% V %*% t(X0) qvec <- - t(X1) %*% V %*% t(X0) @@ -74,7 +76,7 @@ synth_qp <- function(X1, X0, V) { eps_rel = 1e-8, eps_abs = 1e-8) sol <- osqp::solve_osqp(P = Pmat, q = qvec, - A = A, l = l, u = u, + A = A, l = l, u = u, pars = settings) return(sol$x) diff --git a/R/inference.R b/R/inference.R index 0b84216..681451f 100644 --- a/R/inference.R +++ b/R/inference.R @@ -26,7 +26,7 @@ time_jackknife_plus <- function(ascm, alpha = 0.05, conservative = F) { tpost <- ncol(wide_data$y) t_final <- dim(synth_data$Y0plot)[1] - jack_ests <- lapply(1:t0, + jack_ests <- lapply(1:t0, function(tdrop) { # drop unit i new_data <- drop_time_t(wide_data, Z, tdrop) @@ -56,11 +56,11 @@ time_jackknife_plus <- function(ascm, alpha = 0.05, conservative = F) { 
out <- list() att <- predict(ascm, att = T) - out$att <- c(att, + out$att <- c(att, mean(att[(t0 + 1):t_final])) # held out ATT - out$heldout_att <- c(held_out_errs, - att[(t0 + 1):t_final], + out$heldout_att <- c(held_out_errs, + att[(t0 + 1):t_final], mean(att[(t0 + 1):t_final])) # out$se <- rep(NA, 10 + tpost) @@ -95,7 +95,7 @@ drop_time_t <- function(wide_data, Z, t_drop) { new_wide_data <- list() new_wide_data$trt <- wide_data$trt new_wide_data$X <- wide_data$X[, -t_drop, drop = F] - new_wide_data$y <- cbind(wide_data$X[, t_drop, drop = F], + new_wide_data$y <- cbind(wide_data$X[, t_drop, drop = F], wide_data$y) X0 <- new_wide_data$X[new_wide_data$trt == 0,, drop = F] @@ -113,7 +113,7 @@ drop_time_t <- function(wide_data, Z, t_drop) { return(list(wide_data = new_wide_data, synth_data = new_synth_data, - Z = Z)) + Z = Z)) } #' Conformal inference procedure to compute p-values and point-wise confidence intervals @@ -134,6 +134,7 @@ drop_time_t <- function(wide_data, Z, t_drop) { #' \item{"p_val"}{p-value for test of no post-treatment effect} #' \item{"alpha"}{Level of confidence interval} #' } + conformal_inf <- function(ascm, alpha = 0.05, stat_func = NULL, type = "iid", q = 1, ns = 1000, grid_size = 50) { @@ -177,7 +178,7 @@ conformal_inf <- function(ascm, alpha = 0.05, new_wide_data <- wide_data new_wide_data$X <- cbind(wide_data$X, wide_data$y) new_wide_data$y <- matrix(1, nrow = n, ncol = 1) - null_p <- compute_permute_pval(new_wide_data, ascm, 0, ncol(wide_data$y), + null_p <- compute_permute_pval(new_wide_data, ascm, 0, ncol(wide_data$y), type, q, ns, stat_func) out <- list() att <- predict(ascm, att = T) @@ -273,7 +274,7 @@ conformal_inf_linear <- function(ascm, alpha = 0.05, #' @param q The norm for the test static `((sum(x ^ q))) ^ (1/q)` #' @param ns Number of resamples for "iid" permutations #' @param stat_func Function to compute test statistic -#' +#' #' @return List that contains: #' \itemize{ #' \item{"resids"}{Residuals after enforcing the null} @@ -315,7 +316,7 @@ compute_permute_test_stats <- function(wide_data, ascm, h0, stat_func <- function(x) (sum(abs(x) ^ q) / sqrt(length(x))) ^ (1 / q) } if(type == "iid") { - test_stats <- sapply(1:ns, + test_stats <- sapply(1:ns, function(x) { reorder <- sample(resids) stat_func(reorder[(t0 + 1):tpost]) @@ -328,7 +329,7 @@ compute_permute_test_stats <- function(wide_data, ascm, h0, stat_func(reorder[(t0 + 1):tpost]) }) } - + return(list(resids = resids, test_stats = test_stats, stat_func = stat_func)) @@ -344,7 +345,7 @@ compute_permute_test_stats <- function(wide_data, ascm, h0, #' @param q The norm for the test static `((sum(x ^ q))) ^ (1/q)` #' @param ns Number of resamples for "iid" permutations #' @param stat_func Function to compute test statistic -#' +#' #' @return Computed p-value #' @noRd compute_permute_pval <- function(wide_data, ascm, h0, @@ -366,7 +367,7 @@ compute_permute_pval <- function(wide_data, ascm, h0, #' @param q The norm for the test static `((sum(x ^ q))) ^ (1/q)` #' @param ns Number of resamples for "iid" permutations #' @param stat_func Function to compute test statistic -#' +#' #' @return (lower bound of interval, upper bound of interval, p-value for null of 0 effect) #' @noRd compute_permute_ci <- function(wide_data, ascm, grid, @@ -374,9 +375,9 @@ compute_permute_ci <- function(wide_data, ascm, grid, q, ns, stat_func) { # make sure 0 is in the grid grid <- c(grid, 0) - ps <-sapply(grid, + ps <-sapply(grid, function(x) { - compute_permute_pval(wide_data, ascm, x, + compute_permute_pval(wide_data, 
ascm, x, post_length, type, q, ns, stat_func) }) c(min(grid[ps >= alpha]), max(grid[ps >= alpha]), ps[grid == 0]) @@ -447,7 +448,7 @@ time_jackknife_plus_multiout <- function(ascm_multi, alpha = 0.05, conservative t_final <- t0 + tpost Z <- wide_data$Z - jack_ests <- lapply(1:t0, + jack_ests <- lapply(1:t0, function(tdrop) { # drop unit i new_data_list <- drop_time_t_multiout(data_list, Z, tdrop) @@ -483,15 +484,15 @@ time_jackknife_plus_multiout <- function(ascm_multi, alpha = 0.05, conservative out <- list() att <- predict(ascm_multi, att = T) - out$att <- rbind(att, + out$att <- rbind(att, colMeans(att[(t0 + 1):t_final, , drop = F])) # held out ATT - out$heldout_att <- rbind(t(held_out_errs), - att[(t0 + 1):t_final, , drop = F], + out$heldout_att <- rbind(t(held_out_errs), + att[(t0 + 1):t_final, , drop = F], colMeans(att[(t0 + 1):t_final, , drop = F])) if(conservative) { - qerr <- apply(abs(held_out_errs), 1, + qerr <- apply(abs(held_out_errs), 1, stats::quantile, 1 - alpha, type = 1) out$lb <- rbind(matrix(NA, nrow = t0, ncol = k), t(t(apply(jack_dist_cons, 1:2, min)) - qerr)) @@ -502,8 +503,8 @@ time_jackknife_plus_multiout <- function(ascm_multi, alpha = 0.05, conservative out$lb <- rbind(matrix(NA, nrow = t0, ncol = k), apply(jack_dist_low, 1:2, stats::quantile, alpha, type = 1)) - out$ub <- rbind(matrix(NA, nrow = t0, ncol = k), - apply(jack_dist_high, 1:2, + out$ub <- rbind(matrix(NA, nrow = t0, ncol = k), + apply(jack_dist_high, 1:2, stats::quantile, 1 - alpha, type = 1)) } # shift back to ATT scale @@ -532,7 +533,7 @@ drop_time_t_multiout <- function(data_list, Z, t_drop) { function(x) x[, -t_drop, drop = F]) new_data_list$y <- lapply(1:length(data_list$y), function(k) { - cbind(data_list$X[[k]][, t_drop, drop = F], + cbind(data_list$X[[k]][, t_drop, drop = F], data_list$y[[k]]) }) return(new_data_list) @@ -557,7 +558,7 @@ drop_time_t_multiout <- function(data_list, Z, t_drop) { #' \item{"p_val"}{p-value for test of no post-treatment effect} #' \item{"alpha"}{Level of confidence interval} #' } -conformal_inf_multiout <- function(ascm_multi, alpha = 0.05, +conformal_inf_multiout <- function(ascm_multi, alpha = 0.05, stat_func = NULL, type = "iid", q = 1, ns = 1000, grid_size = 1, lin_h0 = NULL) { @@ -577,7 +578,7 @@ conformal_inf_multiout <- function(ascm_multi, alpha = 0.05, post_att <- att[(t0 +1):t_final,, drop = F] post_sd <- apply(post_att, 2, function(x) sqrt(mean(x ^ 2, na.rm = T))) # iterate over post-treatment periods to get pointwise CIs - + vapply(1:tpost, function(j) { # fit using t0 + j as a pre-treatment period and get residuals @@ -589,8 +590,8 @@ conformal_inf_multiout <- function(ascm_multi, alpha = 0.05, colnames(data_list$y[[i]])[j]) Xi }) - - + + if(tpost > 1) { new_data_list$y <- lapply(1:k, function(i) { @@ -609,7 +610,7 @@ conformal_inf_multiout <- function(ascm_multi, alpha = 0.05, # make a grid around the estimated ATT if(is.null(lin_h0)) { - grid <- lapply(1:k, + grid <- lapply(1:k, function(i) { seq(att[t0 + j, i] - 2 * post_sd[i], att[t0 + j, i] + 2 * post_sd[i], length.out = grid_size) @@ -627,7 +628,7 @@ conformal_inf_multiout <- function(ascm_multi, alpha = 0.05, compute_permute_pval_multiout(new_data_list, ascm_multi, numeric(k), 1, type, q, ns, stat_func)) } - + }, matrix(0, ncol = k, nrow=3)) -> cis # # test a null post-treatment effect @@ -646,10 +647,10 @@ conformal_inf_multiout <- function(ascm_multi, alpha = 0.05, data_list$y[[i]][, 1, drop = FALSE] }) null_p <- compute_permute_pval_multiout(new_data_list, ascm_multi, - numeric(k), + 
numeric(k), tpost, type, q, ns, stat_func) if(is.null(lin_h0)) { - grid <- lapply(1:k, + grid <- lapply(1:k, function(i) { seq(min(att[(t0 + 1):tpost, i]) - 4 * post_sd[i], max(att[(t0 + 1):tpost, i]) + 4 * post_sd[i], @@ -698,7 +699,7 @@ conformal_inf_multiout <- function(ascm_multi, alpha = 0.05, #' @param q The norm for the test static `((sum(x ^ q))) ^ (1/q)` #' @param ns Number of resamples for "iid" permutations #' @param stat_func Function to compute test statistic -#' +#' #' @return List that contains: #' \itemize{ #' \item{"resids"}{Residuals after enforcing the null} @@ -739,7 +740,7 @@ compute_permute_test_stats_multiout <- function(data_list, ascm_multi, h0, } } if(type == "iid") { - test_stats <- sapply(1:ns, + test_stats <- sapply(1:ns, function(x) { idxs <- sample(1:nrow(resids)) reorder <- resids[idxs, , drop = F] @@ -756,7 +757,7 @@ compute_permute_test_stats_multiout <- function(data_list, ascm_multi, h0, apply(reorder[(t0 + 1):tpost, , drop = F], 2, stat_func) }) } - + return(list(resids = resids, test_stats = matrix(test_stats, nrow = k), stat_func = stat_func)) @@ -772,7 +773,7 @@ compute_permute_test_stats_multiout <- function(data_list, ascm_multi, h0, #' @param q The norm for the test static `((sum(x ^ q))) ^ (1/q)` #' @param ns Number of resamples for "iid" permutations #' @param stat_func Function to compute test statistic -#' +#' #' @return Computed p-value #' @noRd compute_permute_pval_multiout <- function(data_list, ascm_multi, h0, @@ -805,7 +806,7 @@ compute_permute_pval_multiout <- function(data_list, ascm_multi, h0, #' @param q The norm for the test static `((sum(x ^ q))) ^ (1/q)` #' @param ns Number of resamples for "iid" permutations #' @param stat_func Function to compute test statistic -#' +#' #' @return (lower bound of interval, upper bound of interval, p-value for null of 0 effect) #' @noRd compute_permute_ci_multiout <- function(data_list, ascm_multi, grid, @@ -827,12 +828,12 @@ compute_permute_ci_multiout <- function(data_list, ascm_multi, grid, } ps <- apply(grid, 1, function(x) { - compute_permute_pval_multiout(data_list, ascm_multi, x, + compute_permute_pval_multiout(data_list, ascm_multi, x, post_length, type, q, ns, stat_func) }) - sapply(1:k, - function(i) c(min(grid[ps >= alpha, i]), - max(grid[ps >= alpha, i]), + sapply(1:k, + function(i) c(min(grid[ps >= alpha, i]), + max(grid[ps >= alpha, i]), ps[apply(grid == 0, 1, all)])) } @@ -890,7 +891,7 @@ drop_unit_i_multiout <- function(wide_list, Z, i) { #' Estimate standard errors for single ASCM with the jackknife #' Do this for ridge-augmented synth #' @param ascm Fitted augsynth object -#' +#' #' @return List that contains: #' \itemize{ #' \item{"att"}{Vector of ATT estimates} @@ -1152,7 +1153,7 @@ weighted_bootstrap_multi <- function(multisynth, function(x) mean(x, na.rm=T)) upper_bound <- att - apply(bs_est, c(1,2), function(x) quantile(x, alpha / 2, na.rm = T)) - + lower_bound <- att - apply(bs_est, c(1,2), function(x) quantile(x, 1 - alpha / 2, na.rm = T)) @@ -1164,24 +1165,38 @@ weighted_bootstrap_multi <- function(multisynth, } -#' Bayesian bootstrap + +#' Bootstrap Functions +#' +#' There are several helper functions used for bootstrap inference. +#' Each method returns a list of selection weights of length n, +#' which weights that sum to n. 
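+#'
+#' @examples
+#' # Illustrative only: one draw of Bayesian-bootstrap selection weights
+#' # for a hypothetical donor pool of 10 units.
+#' w <- rdirichlet_b(10)
+#' sum(w)  # rescaled to sum to n (= 10)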
+#' #' @param n Number of units +#' @name bootstrap_methods +#' @rdname bootstrap_methods +NULL + + + +#' @describeIn bootstrap_methods Bayesian bootstrap #' @export rdirichlet_b <- function(n) { Z <- as.numeric(rgamma(n, 1, 1)) return(Z / sum(Z) * n) } -#' Non-parametric bootstrap -#' @param n Number of units +#' @describeIn bootstrap_methods Non-parametric bootstrap #' @export -rmultinom_b <- function(n) as.numeric(rmultinom(1, n, rep(1 / n, n))) +rmultinom_b <- function(n) { + as.numeric(rmultinom(1, n, rep(1 / n, n))) +} -#' Wild bootstrap (Mammen 1993) -#' @param n Number of units + +#' @describeIn bootstrap_methods Wild bootstrap (Mammen 1993) #' @export rwild_b <- function(n) { sample(c(-(sqrt(5) - 1) / 2, (sqrt(5) + 1) / 2 ), n, replace = TRUE, prob = c((sqrt(5) + 1)/ (2 * sqrt(5)), (sqrt(5) - 1) / (2 * sqrt(5)))) -} \ No newline at end of file +} diff --git a/R/make_synth_data.R b/R/make_synth_data.R new file mode 100644 index 0000000..4be0973 --- /dev/null +++ b/R/make_synth_data.R @@ -0,0 +1,133 @@ + +# Utility function to make fake data for testing and whatnot + + + +#' A synthetic data simulator +#' +#' This generates data with time shocks that interact with the unit +#' latent factors. +#' +#' It also gives a unit fixed effect and a time fixed +#' effect. +#' +#' @param n_time Number of time periods +#' @param n_post Number of the time periods that are after tx onset +#' @param n_U Number of latent factors +#' @param N Number of units +#' @param N_tx number of treated units, out of total N +#' @param sd_time How much the time varying shocks matter. +#' @param sd_unit_fe How much the units have different shifts +#' @param sd_time_fe How much the time intercept shifts vary +#' @param tx_shift Add this shift to the distribution of the +#' treatment latent factors to create differences in tx and co +#' groups. +#' +#' @noRd +make_synth_data = function(n_U, N, n_time, N_tx = 1, n_post = 3, + long_form = FALSE, + tx_impact = 1:n_post, + sd_time = 0.3, + sd_unit_fe = 1, + sd_time_fe = 0.2, + sd_e = 1, + tx_shift = 0.5 ) { + + stopifnot( N_tx < N ) + stopifnot( n_post < n_time ) + + ## (1) Make latent factors for units + + # Make correlation structure for latent factors + Sigma = matrix(0.15, nrow = n_U, ncol = n_U) + diag(Sigma) = 1 + #solve(Sigma) + U = MASS::mvrnorm(N, mu = rep(1, n_U), Sigma = Sigma) + #U = abs( U ) + U = cbind( U, 1 ) + U[,1] = sd_unit_fe * U[,1] + #U + #summary(U) + U[1:N_tx,1:n_U] = U[1:N_tx,1:n_U] + tx_shift + + + # (2) Make the time varying component, with the first row being an + # intercept + shocks = matrix(rnorm(n_U * n_time, sd=sd_time), nrow = n_U) + #shocks = abs( shocks ) + shocks = rbind( 1, shocks ) + #shocks[1, ] = 2 * sort(shocks[1, ]) + #shocks[2, ] = sort(shocks[2, ]) + #shocks[2, ] = 0.5 * rev(sort(shocks[2, ])) + #browser() + + # Make a time drift by having the fixed effect time shock increase over time. + shocks[n_U+1,] = sort( shocks[n_U+1,] ) + shocks[n_U+1,] = sd_time_fe * shocks[n_U+1,] / sd_time + # Alt way to create time drift + # shocks = shocks + rep( 1:n_time, each=n_U ) / n_time + + #shocks = shocks + (1:nrow(shocks))/nrow(shocks) + #qplot(1:n_time, shocks[1, ]) + #dim(U) + #dim(shocks) + + # Outcome is latent factors times time shocks, with extra noise + # added. 
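+  # (Y is N x n_time: U is N x (n_U + 1) after the appended constant column
+  # and shocks is (n_U + 1) x n_time. The constant first row of shocks picks
+  # up the unit fixed effect stored in U[, 1]; the constant last column of U
+  # picks up the drifting time fixed effect in the sorted, rescaled last row
+  # of shocks.)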
+ Y = U %*% shocks + Y = Y + rnorm( length(Y), sd = sd_e ) + #dim(Y) + + dat = as.data.frame(Y) + colnames(dat) = 1:n_time + dat$ID = 1:N + dat$Tx = 0 + dat$Tx[1:N_tx] = 1 + + + # Add in some covariates predictive of latent U (skipping + # intercept) + X = U[,1:n_U, drop=FALSE] + for (i in 1:n_U) { + X[, i] = X[, i] + rnorm(nrow(X)) + } + X = round( X, digits = 1 ) + + #head(dat) + colnames(X) = paste0("X", 1:ncol(X)) + X = as.data.frame(X) + dat = bind_cols(dat, X) + + # Add in a treatment impact! + tx_index = (n_time-n_post+1):n_time + imp = matrix( tx_impact, + ncol=n_post, + nrow=N_tx, + byrow = TRUE ) + dat[ dat$Tx == 1, tx_index ] = dat[ dat$Tx == 1, tx_index ] + imp + + + # Final packaging of the data + if ( long_form ) { + + ldat = pivot_longer( + dat, + cols = all_of(1:n_time), + names_to = "time", + values_to = "Y" + ) + + # Make Y look nice. + ldat$Y = round( ldat$Y, digits = 1 ) #round( ldat$Y / 5, digits = 1 ) * 5 + + ldat$time = as.numeric(ldat$time) + ldat = mutate( ldat, + ever_Tx = Tx, + Tx = ifelse( ever_Tx & time > n_time - n_post, 1, 0 ) ) + ldat + } else { + dat + } +} + + diff --git a/R/multi_outcomes.R b/R/multi_outcomes.R index 89663d5..cb2fdbb 100644 --- a/R/multi_outcomes.R +++ b/R/multi_outcomes.R @@ -8,7 +8,7 @@ #' Ridge=Ridge regression (allows for standard errors), #' None=No outcome model, #' @param scm Whether the SCM weighting function is used -#' @param fixedeff Whether to include a unit fixed effect, default F +#' @param fixedeff Whether to include a unit fixed effect, default F #' @param cov_agg Covariate aggregation functions, if NULL then use mean with NAs omitted #' @param combine_method How to combine outcomes: `concat` concatenates outcomes and `avg` averages them, default: 'avg' #' @param ... optional arguments for outcome model @@ -31,10 +31,15 @@ augsynth_multiout <- function(form, unit, time, t_int, data, ...) { call_name <- match.call() + # if a user sets inf_type at time of model fit, return a warning + if("inf_type" %in% names(call_name)) { + warning("`inf_type` is not an argument for augsynth_multiout, so it is ignored") + } + form <- Formula::Formula(form) unit <- enquo(unit) time <- enquo(time) - + ## format data outcome <- terms(formula(form, rhs=1))[[2]] trt <- terms(formula(form, rhs=1))[[3]] @@ -43,9 +48,9 @@ augsynth_multiout <- function(form, unit, time, t_int, data, outcomes <- sapply(outcomes_str, quo) # get outcomes as a list wide_list <- format_data_multi(outcomes, trt, unit, time, t_int, data) - - + + ## add covariates if(length(form)[2] == 2) { @@ -69,11 +74,11 @@ augsynth_multiout <- function(form, unit, time, t_int, data, # add some extra data augsynth$data$time <- data %>% distinct(!!time) %>% pull(!!time) augsynth$call <- call_name - augsynth$t_int <- t_int + augsynth$t_int <- t_int augsynth$combine_method <- combine_method treated_units <- data %>% filter(!!trt == 1) %>% distinct(!!unit) %>% pull(!!unit) - control_units <- data %>% filter(!(!!unit %in% treated_units)) %>% + control_units <- data %>% filter(!(!!unit %in% treated_units)) %>% distinct(!!unit) %>% pull(!!unit) augsynth$weights <- matrix(augsynth$weights) rownames(augsynth$weights) <- control_units @@ -92,7 +97,7 @@ augsynth_multiout <- function(form, unit, time, t_int, data, #' @param ... Extra args for outcome model #' @noRd fit_augsynth_multiout_internal <- function(wide_list, combine_method, Z, - progfunc, scm, fixedeff, + progfunc, scm, fixedeff, outcomes_str, ...) 
{ @@ -176,6 +181,7 @@ combine_outcomes <- function(wide_list, combine_method, fixedeff, # combine outcomes if(combine_method == "concat") { + # center X and scale by overall variance for outcome # X <- lapply(wide_list$X, function(x) t(t(x) - colMeans(x)) / sd(x)) wide_bal <- list(X = do.call(cbind, lapply(wide_list$X, function(x) t(na.omit(t(x))))), @@ -195,8 +201,8 @@ combine_outcomes <- function(wide_list, combine_method, fixedeff, # trt = wide_list$trt) # # first get the standard deviations of the outcomes to put on the same scale - # sds <- do.call(c, - # lapply(wide_list$X, + # sds <- do.call(c, + # lapply(wide_list$X, # function(x) rep((sqrt(ncol(x)) * sd(x, na.rm=T)), ncol(x)))) # # do an SVD on centered and scaled outcomes @@ -273,33 +279,33 @@ predict.augsynth_multiout <- function(object, ...) { # separate out by outcome n_outs <- length(object$data_list$X) - max_t <- max(sapply(1:n_outs, + max_t <- max(sapply(1:n_outs, function(k) ncol(object$data_list$X[[k]]) + ncol(object$data_list$y[[k]]))) - pred_reshape <- matrix(NA, ncol = n_outs, + pred_reshape <- matrix(NA, ncol = n_outs, nrow = max_t) - colnames <- lapply(1:n_outs, - function(k) colnames(cbind(object$data_list$X[[k]], + colnames <- lapply(1:n_outs, + function(k) colnames(cbind(object$data_list$X[[k]], object$data_list$y[[k]]))) rownames(pred_reshape) <- colnames[[which.max(sapply(colnames, length))]] colnames(pred_reshape) <- object$outcomes # get outcome names for predictions - pre_outs <- do.call(c, - sapply(1:n_outs, + pre_outs <- do.call(c, + sapply(1:n_outs, function(j) { rep(object$outcomes[j], ncol(object$data_list$X[[j]])) }, simplify = FALSE)) - + post_outs <- do.call(c, - sapply(1:n_outs, + sapply(1:n_outs, function(j) { rep(object$outcomes[j], ncol(object$data_list$y[[j]])) }, simplify = FALSE)) # print(pred) # print(cbind(names(pred), c(pre_outs, post_outs))) - + pred_reshape[cbind(names(pred), c(pre_outs, post_outs))] <- pred return(pred_reshape) } @@ -330,8 +336,10 @@ print.augsynth_multiout <- function(x, ...) { #' @param grid_size Grid to compute prediction intervals over, default is 1 and only p-values are computed #' @param ... Optional arguments, including \itemize{\item{"se"}{Whether to plot standard error}} #' @export + summary.augsynth_multiout <- function(object, inf = T, inf_type = "conformal", grid_size = 1, ...) { - + + summ <- list() @@ -430,11 +438,7 @@ summary.augsynth_multiout <- function(object, inf = T, inf_type = "conformal", g if(grid_size == 1) { average_att <- average_att %>% mutate(lower_bound = NA, upper_bound = NA) } - # } - # } else { - # average_att <- gather(att_avg, Outcome, Estimate) - # } - + } else { att_est <- predict(object, att = T) @@ -489,9 +493,9 @@ summary.augsynth_multiout <- function(object, inf = T, inf_type = "conformal", g summ$bias_est <- NA } - - + + class(summ) <- "summary.augsynth_multiout" return(summ) } @@ -504,10 +508,10 @@ summary.augsynth_multiout <- function(object, inf = T, inf_type = "conformal", g print.summary.augsynth_multiout <- function(x, ...) { ## straight from lm cat("\nCall:\n", paste(deparse(x$call), sep="\n", collapse="\n"), "\n\n", sep="") - + att_est <- x$att$Estimate ## get pre-treatment fit by outcome - imbal <- x$att %>% + imbal <- x$att %>% filter(Time < x$t_int) %>% group_by(Outcome) %>% summarise(Pre.RMSE = sqrt(mean(Estimate ^ 2, na.rm = TRUE))) @@ -530,7 +534,7 @@ print.summary.augsynth_multiout <- function(x, ...) 
{ #' @param inf Boolean, whether to plot uncertainty intervals, default TRUE #' @param plt_avg Boolean, whether to plot the average of the outcomes, default FALSE #' @param ... Optional arguments for summary function -#' +#' #' @export plot.augsynth_multiout <- function(x, inf = T, plt_avg = F, ...) { plot(summary(x, ...), inf = inf, plt_avg = plt_avg) @@ -540,7 +544,7 @@ plot.augsynth_multiout <- function(x, inf = T, plt_avg = F, ...) { #' @param x summary.augsynth_multiout object #' @param inf Boolean, whether to plot uncertainty intervals, default TRUE #' @param plt_avg Boolean, whether to plot the average of the outcomes, default FALSE -#' +#' #' @export plot.summary.augsynth_multiout <- function(x, inf = F, plt_avg = F, ...) { if(plt_avg) { @@ -548,7 +552,7 @@ plot.summary.augsynth_multiout <- function(x, inf = F, plt_avg = F, ...) { ggplot2::ggplot(ggplot2::aes(x=Time, y=Estimate)) } else { p <- x$att %>% - filter(Outcome != "Average") %>% + filter(Outcome != "Average") %>% ggplot2::ggplot(ggplot2::aes(x=Time, y=Estimate)) } if(inf) { @@ -569,4 +573,4 @@ plot.summary.augsynth_multiout <- function(x, inf = F, plt_avg = F, ...) { ggplot2::facet_wrap(~ Outcome, scales = "free_y") + ggplot2::theme_bw() -} \ No newline at end of file +} diff --git a/R/multisynth_class.R b/R/multisynth_class.R index b9c1d04..837ed3b 100644 --- a/R/multisynth_class.R +++ b/R/multisynth_class.R @@ -1,3 +1,5 @@ + + ################################################################################ ## Fitting, plotting, summarizing staggered synth ################################################################################ @@ -29,7 +31,7 @@ #' @param eps_rel Relative error tolerance for osqp #' @param verbose Whether to print logs for osqp #' @param ... Extra arguments -#' +#' #' @return multisynth object that contains: #' \itemize{ #' \item{"weights"}{weights matrix where each column is a set of weights for a treated unit} @@ -67,6 +69,11 @@ multisynth <- function(form, unit, time, data, unit <- enquo(unit) time <- enquo(time) + # if a user sets inf_type at time of model fit, return a warning + if("inf_type" %in% names(call_name)) { + warning("`inf_type` is not an argument for multisynth, so it is ignored") + } + ## format data outcome <- terms(formula(form, rhs=1))[[2]] trt <- terms(formula(form, rhs=1))[[3]] @@ -142,10 +149,10 @@ multisynth <- function(form, unit, time, data, scm = scm, time_cohort = time_cohort, time_w = F, lambda_t = 0, fit_resids = TRUE, eps_abs = eps_abs, - eps_rel = eps_rel, verbose = verbose, long_df = long_df, + eps_rel = eps_rel, verbose = verbose, long_df = long_df, how_match = how_match, ...) - - + + units <- data %>% arrange(!!unit) %>% distinct(!!unit) %>% pull(!!unit) rownames(msynth$weights) <- units @@ -166,7 +173,7 @@ multisynth <- function(form, unit, time, data, V = V, time_cohort = time_cohort, donors = msynth$donors, - eps_rel = eps_rel, + eps_rel = eps_rel, eps_abs = eps_abs, verbose = verbose) ## scaled global balance @@ -179,6 +186,10 @@ multisynth <- function(form, unit, time, data, } msynth$call <- call_name + msynth$trt_unit <- msynth$data$units[ msynth$data$trt < Inf ] + msynth$time_var <- quo_name(time) + msynth$unit_var <- quo_name(unit) + msynth$form <- form return(msynth) @@ -211,11 +222,11 @@ multisynth_formatted <- function(wide, relative=T, n_leads, n_lags, nu, lambda, V, force, n_factors, - scm, time_cohort, + scm, time_cohort, time_w, lambda_t, fit_resids, eps_abs, eps_rel, - verbose, long_df, + verbose, long_df, how_match, ...) 
{ ## average together treatment groups ## grps <- unique(wide$trt) %>% sort() @@ -246,23 +257,23 @@ multisynth_formatted <- function(wide, relative=T, n_leads, n_lags, n_factors <- ncol(params$factor) ## get residuals from outcome model residuals <- cbind(wide$X, wide$y) - y0hat - + } else if (n_factors != 0) { ## if number of factors is provided don't do CV out <- fit_gsynth_multi(long_df, cbind(wide$X, wide$y), wide$trt, r=n_factors, CV=0, force=force) y0hat <- out$y0hat - params <- out$params - + params <- out$params + ## get residuals from outcome model residuals <- cbind(wide$X, wide$y) - y0hat } else if(force == 0 & n_factors == 0) { - # if no fixed effects or factors, just take out + # if no fixed effects or factors, just take out # control averages at each time point # time fixed effects from pure controls pure_ctrl <- cbind(wide$X, wide$y)[!is.finite(wide$trt), , drop = F] y0hat <- matrix(colMeans(pure_ctrl, na.rm = TRUE), - nrow = nrow(wide$X), ncol = ncol(pure_ctrl), + nrow = nrow(wide$X), ncol = ncol(pure_ctrl), byrow = T) residuals <- cbind(wide$X, wide$y) - y0hat params <- NULL @@ -297,13 +308,13 @@ multisynth_formatted <- function(wide, relative=T, n_leads, n_lags, bal_mat <- wide$X - ctrl_avg bal_mat <- wide$X } - + if(scm) { # get eligible set of donor units based on covariates donors <- get_donors(wide$X, wide$y, wide$trt, - wide$Z[, colnames(wide$Z) %in% + wide$Z[, colnames(wide$Z) %in% wide$match_covariates, drop = F], time_cohort, n_lags, n_leads, how = how_match, exact_covariates = wide$exact_covariates, ...) @@ -386,9 +397,9 @@ multisynth_formatted <- function(wide, relative=T, n_leads, n_lags, msynth$time_w <- time_w msynth$lambda_t <- lambda_t msynth$fit_resids <- fit_resids - msynth$extra_pars <- c(list(eps_abs = eps_abs, - eps_rel = eps_rel, - verbose = verbose), + msynth$extra_pars <- c(list(eps_abs = eps_abs, + eps_rel = eps_rel, + verbose = verbose), list(...)) msynth$long_df <- long_df @@ -417,7 +428,7 @@ predict.multisynth <- function(object, att = F, att_weight = NULL, bs_weight = N multisynth <- object relative <- T - + time_cohort <- multisynth$time_cohort if(is.null(relative)) { relative <- multisynth$relative @@ -436,19 +447,19 @@ predict.multisynth <- function(object, att = F, att_weight = NULL, bs_weight = N } if(time_cohort) { - which_t <- lapply(grps, + which_t <- lapply(grps, function(tj) (1:n)[multisynth$data$trt == tj]) mask <- unique(multisynth$data$mask) } else { which_t <- (1:n)[is.finite(multisynth$data$trt)] mask <- multisynth$data$mask } - + n1 <- sapply(1:J, function(j) length(which_t[[j]])) fullmask <- cbind(mask, matrix(0, nrow = J, ncol = (ttot - d))) - + ## estimate the post-treatment values to get att estimates mu1hat <- vapply(1:J, @@ -502,11 +513,11 @@ predict.multisynth <- function(object, att = F, att_weight = NULL, bs_weight = N c(vec, rep(NA, total_len - length(vec)), mean(mu0hat[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j])) - + }, numeric(total_len +1 )) - + tauhat <- vapply(1:J, function(j) { vec <- c(rep(NA, d-grps[j]), @@ -536,11 +547,11 @@ predict.multisynth <- function(object, att = F, att_weight = NULL, bs_weight = N c(vec, rep(NA, total_len - length(vec)), mean(att_weight[(grps[j]+1):(min(grps[j] + n_leads, ttot)), j])) - + }, numeric(total_len +1 )) - + ## get the overall average estimate avg <- apply(mu0hat, 1, function(z) sum(n1 * z, na.rm=T) / sum(n1 * !is.na(z))) avg <- sapply(1:nrow(mu0hat), @@ -557,7 +568,7 @@ predict.multisynth <- function(object, att = F, att_weight = NULL, bs_weight = N sum(n1 * 
(!is.na(tauhat[k,])) * att_weight_new[k, ], na.rm = T) }) tauhat <- cbind(avg, tauhat) - + } else { ## remove all estimates for t > T_j + n_leads @@ -571,7 +582,7 @@ predict.multisynth <- function(object, att = F, att_weight = NULL, bs_weight = N rep(NA, max(0, ttot-(grps[j] + n_leads)))), numeric(ttot)) -> tauhat - + ## only average currently treated units avg1 <- rowSums(t(fullmask) * mu0hat * n1) / rowSums(t(fullmask) * n1) @@ -590,7 +601,7 @@ predict.multisynth <- function(object, att = F, att_weight = NULL, bs_weight = N replace_na(avg2,0) * apply(1 - fullmask, 2, max) cbind(avg, tauhat) -> tauhat } - + if(att) { return(tauhat) @@ -606,9 +617,9 @@ predict.multisynth <- function(object, att = F, att_weight = NULL, bs_weight = N #' @export print.multisynth <- function(x, att_weight = NULL, ...) { multisynth <- x - + ## straight from lm - cat("\nCall:\n", paste(deparse(multisynth$call), + cat("\nCall:\n", paste(deparse(multisynth$call), sep="\n", collapse="\n"), "\n\n", sep="") # print att estimates @@ -636,7 +647,7 @@ print.multisynth <- function(x, att_weight = NULL, ...) { #' @param ... Optional arguments #' @export plot.multisynth <- function(x, inf_type = "bootstrap", inf = T, - levels = NULL, label = T, + levels = NULL, label = T, weights = FALSE, ...) { if(weights) { @@ -651,13 +662,13 @@ plot.multisynth <- function(x, inf_type = "bootstrap", inf = T, # plotting the weights weights %>% tidyr::pivot_longer(-unit, names_to = "trt_unit", values_to = "weight") %>% - ggplot2::ggplot(aes(x = trt_unit, y = unit, fill = round(weight, 3))) + + ggplot2::ggplot(aes(x = trt_unit, y = unit, fill = round(weight, 3))) + ggplot2::geom_tile(color = "white", size=.5) + ggplot2::scale_fill_gradient(low = "white", high = "black", limits=c(-0.01,1.01)) + ggplot2::guides(fill = "none") + ggplot2::xlab("Treated Unit") + ggplot2::ylab("Donor Unit") + - ggplot2::theme_bw() + + ggplot2::theme_bw() + ggplot2::theme(axis.ticks.x = ggplot2::element_blank(), axis.ticks.y = ggplot2::element_blank()) } @@ -679,7 +690,7 @@ plot.multisynth <- function(x, inf_type = "bootstrap", inf = T, #' \item{jackknife}{Jackknife} #' } #' @param ... Optional arguments -#' +#' #' @return summary.multisynth object that contains: #' \itemize{ #' \item{"att"}{Dataframe with ATT estimates, standard errors for each treated unit} @@ -693,7 +704,7 @@ plot.multisynth <- function(x, inf_type = "bootstrap", inf = T, summary.multisynth <- function(object, inf_type = "bootstrap", att_weight = NULL, ...) { multisynth <- object - + relative <- T n_leads <- multisynth$n_leads @@ -710,17 +721,17 @@ summary.multisynth <- function(object, inf_type = "bootstrap", att_weight = NULL grps <- trt[is.finite(trt)] which_t <- (1:n)[is.finite(trt)] } - + # grps <- unique(multisynth$data$trt) %>% sort() J <- length(grps) - + # which_t <- (1:n)[is.finite(multisynth$data$trt)] times <- multisynth$data$time - + summ <- list() ## post treatment estimate for each group and overall # att <- predict(multisynth, relative, att=T) - + if(inf_type == "jackknife") { attse <- jackknife_se_multi(multisynth, relative, att_weight = att_weight, ...) 
} else if(inf_type == "bootstrap") { @@ -736,16 +747,16 @@ summary.multisynth <- function(object, inf_type = "bootstrap", att_weight = NULL upper_bound = matrix(NA, nrow(att), ncol(att)), lower_bound = matrix(NA, nrow(att), ncol(att))) } - + if(relative) { att <- data.frame(cbind(c(-(d-1):min(n_leads, ttot-min(grps)), NA), attse$att)) if(time_cohort) { - col_names <- c("Time", "Average", + col_names <- c("Time", "Average", as.character(times[grps + 1])) } else { - col_names <- c("Time", "Average", + col_names <- c("Time", "Average", as.character(multisynth$data$units[which_t])) } names(att) <- col_names @@ -754,7 +765,7 @@ summary.multisynth <- function(object, inf_type = "bootstrap", att_weight = NULL mutate(Time=Time-1) -> att se <- data.frame(cbind(c(-(d-1):min(n_leads, ttot-min(grps)), NA), - attse$se)) + attse$se)) names(se) <- col_names se %>% gather(Level, Std.Error, -Time) %>% rename("Time"=Time) %>% @@ -775,11 +786,11 @@ summary.multisynth <- function(object, inf_type = "bootstrap", att_weight = NULL } else { att <- data.frame(cbind(times, attse$att)) - names(att) <- c("Time", "Average", times[grps[1:J]]) + names(att) <- c("Time", "Average", times[grps[1:J]]) att %>% gather(Level, Estimate, -Time) -> att se <- data.frame(cbind(times, attse$se)) - names(se) <- c("Time", "Average", times[grps[1:J]]) + names(se) <- c("Time", "Average", times[grps[1:J]]) se %>% gather(Level, Std.Error, -Time) -> se } @@ -812,12 +823,12 @@ summary.multisynth <- function(object, inf_type = "bootstrap", att_weight = NULL print.summary.multisynth <- function(x, level = "Average", ...) { summ <- x - + ## straight from lm cat("\nCall:\n", paste(deparse(summ$call), sep="\n", collapse="\n"), "\n\n", sep="") first_lvl <- summ$att %>% filter(Level != "Average") %>% pull(Level) %>% min() - + ## get ATT estimates for treatment level, post treatment if(summ$relative) { summ$att %>% @@ -840,18 +851,18 @@ print.summary.multisynth <- function(x, level = "Average", ...) { pull(Std.Error) %>% round(3) %>% format(nsmall=3), ")\n\n", sep="")) - + cat(paste("Global L2 Imbalance: ", format(round(summ$global_l2,3), nsmall=3), "\n", "Scaled Global L2 Imbalance: ", format(round(summ$scaled_global_l2,3), nsmall=3), "\n", - "Percent improvement from uniform global weights: ", + "Percent improvement from uniform global weights: ", format(round(1-summ$scaled_global_l2,3)*100), "\n\n", "Individual L2 Imbalance: ", format(round(summ$ind_l2,3), nsmall=3), "\n", - "Scaled Individual L2 Imbalance: ", + "Scaled Individual L2 Imbalance: ", format(round(summ$scaled_ind_l2,3), nsmall=3), "\n", - "Percent improvement from uniform individual weights: ", + "Percent improvement from uniform individual weights: ", format(round(1-summ$scaled_ind_l2,3)*100), "\t", "\n\n", sep="")) @@ -863,7 +874,7 @@ print.summary.multisynth <- function(x, level = "Average", ...) 
{ #' Plot function for summary function for multisynth #' @importFrom ggplot2 aes -#' +#' #' @param x summary object #' @param inf Whether to plot confidence intervals #' @param levels Which units/groups to plot, default is every group @@ -887,19 +898,19 @@ plot.summary.multisynth <- function(x, inf = T, levels = NULL, label = T, # plotting the weights weights %>% tidyr::pivot_longer(-unit, names_to = "trt_unit", values_to = "weight") %>% - ggplot2::ggplot(aes(x = trt_unit, y = unit, fill = round(weight, 3))) + + ggplot2::ggplot(aes(x = trt_unit, y = unit, fill = round(weight, 3))) + ggplot2::geom_tile(color = "white", size=.5) + ggplot2::scale_fill_gradient(low = "white", high = "black", limits=c(-0.01,1.01)) + ggplot2::guides(fill = "none") + ggplot2::xlab("Treated Unit") + ggplot2::ylab("Donor Unit") + - ggplot2::theme_bw() + + ggplot2::theme_bw() + ggplot2::theme(axis.ticks.x = ggplot2::element_blank(), axis.ticks.y = ggplot2::element_blank()) } summ <- x - + ## get the last time period for each level summ$att %>% filter(!is.na(Estimate), @@ -921,11 +932,11 @@ plot.summary.multisynth <- function(x, inf = T, levels = NULL, label = T, alpha = is_avg)) + ggplot2::geom_line(size = 1) + ggplot2::geom_point(size = 1) -> p - + if(label) { p <- p + ggrepel::geom_label_repel(ggplot2::aes(label = label), nudge_x = 1, na.rm = T) - } + } p <- p + ggplot2::geom_hline(yintercept = 0, lty = 2) if(summ$relative) { @@ -954,10 +965,10 @@ plot.summary.multisynth <- function(x, inf = T, levels = NULL, label = T, ggplot2::aes(ymin=lower_bound, ymax=upper_bound), alpha = alph, color=clr, - data = summ$att %>% + data = summ$att %>% filter(Level == "Average", Time >= 0)) - + } else { p <- p + error_plt( ggplot2::aes(ymin=lower_bound, @@ -969,8 +980,58 @@ plot.summary.multisynth <- function(x, inf = T, levels = NULL, label = T, p <- p + ggplot2::scale_alpha_manual(values=c(1, 0.5)) + ggplot2::scale_color_manual(values=c("#333333", "#818181")) + - ggplot2::guides(alpha=F, color=F) + + ggplot2::guides(alpha=F, color=F) + ggplot2::theme_bw() return(p) } + + + + +#' +#' @return dim: Dimension of data as pair of (# units, # time points). +#' +#' @rdname multisynth_class +#' @export +#' +dim.multisynth <- function(x, ... ) { + n_unit = length( unique( x$long_df[[ x$unit_var ]] ) ) + n_time = x$n_leads + x$n_lags + return( c( n_unit, n_time ) ) +} + + + +#' @title Number of units in multisynth +#' +#' @return Single number (of unique units). +#' +#' @rdname multisynth_class +#' +#' @export +#' +n_unit.multisynth <- function(x, ... ) { + dim.multisynth(x)[[1]] +} + +#' @title Number of time points in multisynth +#' +#' @rdname multisynth_class +#' +#' @return Single number (of unique time points). +#' @export +n_time.multisynth <- function(x, ... ) { + dim.multisynth(x)[[2]] +} + + +#' +#' @rdname multisynth_class +#' +#' @return Number of treated units (always 1 for multisynth) +#' @export +n_treated.multisynth <- function(x, ... ) { + return( length( x$trt_unit ) ) +} + diff --git a/R/permutation.R b/R/permutation.R new file mode 100644 index 0000000..12ef42d --- /dev/null +++ b/R/permutation.R @@ -0,0 +1,481 @@ + + +#' Run augsynth letting each unit be the treated unit. This +#' implements the abadie placebo test (e.g., from tobacco paper) but +#' with the augsynth package. +#' +#' @param tx.id Name/ID of treatment group (not stored in ascm) +#' +#' @return Dataframe With each row corresponding to a unit (including +#' the original treated unit) and each column corresponding to a +#' time point. 
Each entry is the estimated impact for that unit at +#' that time, after fitting augsynth on that unit as the tx unit and +#' the other units as possible controls. +#' +#' @noRd +get_placebo_gaps = function( ascm, att = TRUE ) { + + tx_id <- ascm$trt_unit + wide_data <- ascm$data + synth_data <- ascm$data$synth_data + Z <- wide_data$Z + control_ids = which( wide_data$trt == 0 ) + all_ids = 1:length(wide_data$trt) + t_final = length( wide_data$time ) + + ests <- vapply(all_ids, function(i) { + new_data <- swap_treat_unit(wide_data, Z, i) + new_ascm <- do.call(augsynth:::fit_augsynth_internal, + c(list(wide = new_data$wide, + synth_data = new_data$synth_data, + Z = new_data$Z, + progfunc = ascm$progfunc, scm = ascm$scm, + fixedeff = ascm$fixedeff), + ascm$extra_args)) + est <- predict(new_ascm, att = att ) + est + }, numeric(t_final) ) + + # Verify we recover the original treated unit by seeing if the + # estimates match our original fit model's estimates + dim( ests ) + pds = as.numeric( predict( ascm, att = att ) ) + pds + if( !all( round( ests[ , which( wide_data$trt == 1 ) ] - pds, digits=4 ) == 0 ) ) { + stop( "Two versions of estimated impacts do not correspond. Serious error. Please contact package maintainers." ) + } + + + + ests = as.data.frame( t( ests ) ) + dim( ests ) + ests$ID = tx_id + ests$ID[ control_ids ] = rownames( ascm$weights ) + + ests %>% dplyr::select( ID, everything() ) +} + +#### Generate donor unit from fit synth model #### + +#' Take inner data and change which unit is marked as 'treated' +#' +#' @noRd +swap_treat_unit = function (wide_data, Z, i) { + wide_data$trt <- rep( 0, length( wide_data$trt ) ) + wide_data$trt[i] = 1 + + X0 <- wide_data$X[wide_data$trt == 0, , drop = F] + x1 <- matrix( wide_data$X[wide_data$trt == 1, , drop = F], ncol = 1 ) + y0 <- wide_data$y[wide_data$trt == 0, , drop = F] + y1 <- matrix( wide_data$y[wide_data$trt == 1, , drop = F], ncol = 1 ) + new_synth_data <- list() + new_synth_data$Z0 <- t(X0) + new_synth_data$X0 <- t(X0) + new_synth_data$Z1 <- x1 + new_synth_data$X1 <- x1 + + wide_data = wide_data[ c( "trt", "X", "y" ) ] + + return(list(wide_data = wide_data, + synth_data = new_synth_data, + Z = Z)) +} + + + +#' Calculate MDES +#' +#' Use our method for Calculating SEs, p-values, and MDEs by looking +#' at the distribution of impacts across the donor units. +#' +#' @param lest Entity by year estimate estimates for all treatment and control entities +#' (original tx and all the comparisons as pseudo-treated). Columns +#' of entity ID (tx_col), treatment flag (trt), time period (time_col), +#' and estimated tx-synthetic control difference (ATT). +#' @param treat_year The time period in which the treatment was introduced. +#' @param tx_col The name to the column containing the treatment variable. +#' @param time_col The name of the column containing the time variable. +#' @return List of three dataframes, one of MDES, one of RMSPEs, and +#' one of info on the r-statistics (the ATTs divided by the RMSPEs). +#' +#' @noRd +calculate_MDES_table = function( lest, treat_year, tx_col, time_col ) { + + + stopifnot("Entity by year estimates (`lest`) are missing one of the following necessary features: `tx_col`, 'trt', `time_col`, or 'ATT'" = all( c( tx_col, "trt", time_col, "ATT" ) %in% + names( lest ) ) ) + + RMSPEs = calculate_RMSPE( lest, treat_year, tx_col, time_col ) + + # merge calculated pre-intervention RMSPE for each county to data set + # of gaps for each year for each county + + # Divide that year's gap by pre-intervention RMSPE. 
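+  # Scaling each year's gap by that unit's own pre-period RMSPE gives the
+  # Abadie-style r-statistic, so placebo units with poor pre-treatment fit
+  # do not dominate the permutation distribution.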
+ rstatcalc <- full_join( lest, RMSPEs, by = tx_col ) %>% + mutate( rstat=ATT/RMSPE ) + + # Calculate the permutation p-values (permuting the r statistics) + pvalues = rstatcalc %>% + group_by( !!as.name(time_col) ) %>% + summarise( p_rstat = mean( abs( rstat[trt == 1] ) <= abs( rstat ) ), + SE_rstat = sd( rstat[ trt != 1 ] ) * RMSPE[ trt == 1 ], + p_gap = mean( abs( ATT[ trt == 1 ] ) <= abs( ATT ) ), + SE_gap = sd( ATT[ trt != 1 ] ), + .groups = "drop" ) + + # attach R statistic for the treated tract. + rstattract <- rstatcalc %>% + filter(trt==1) %>% + dplyr::select( -!!as.name(tx_col), -trt ) + + # Make table of main results with permutation p-values and add + # indicator of which years are treated and which not. + main_results <- rstattract %>% + full_join( pvalues, by=time_col ) %>% + mutate( tx = ifelse( !!as.name(time_col) >= treat_year, 1, 0 ) ) + + + # Clean up column ordering and make look nice + main_results = main_results %>% + relocate( !!as.name(time_col), tx, Yobs ) %>% + dplyr::select( -RMSPE ) + + # Bundle and return all results + list( MDES_table = main_results, + RMSPEs = RMSPEs, + rstatcalc=rstatcalc ) +} + + +#' Calculate RMSPE for all units on lagged outcomes +#' +#' This calculates the RMSPE of the difference between estimated and +#' observed outcome, averaged across the pre-treatment years, for each +#' county +#' +#' @param lest Entity by year estimate estimates for all treatment and control entities +#' (original tx and all the comparisons as pseudo-treated). Columns +#' of entity ID (tx_col), treatment flag (trt), time period (time_col), +#' and estimated tx-synthetic control difference (ATT). +#' @param treat_year The time period in which the treatment was introduced. +#' @param tx_col The name to the column containing the treatment variable. +#' @param time_col The name of the column containing the time variable. +#' +#' @noRd +calculate_RMSPE = function( lest, treat_year, tx_col, time_col ) { + stopifnot( !is.null( lest$ATT ) ) + + RMSPEs = lest %>% filter( !!as.name(time_col) < treat_year ) %>% + group_by( !!as.name(tx_col) ) %>% + summarise( RMSPE = sqrt( mean( ATT^2 ) ), .groups = "drop" ) + + return(RMSPEs) +} + +#' Estimate robust SE +#' +#' Given a list of impact estimates (all assumed to be estimating a +#' target impact of 0), calculate the modified standard deviation by +#' looking an in inter-quartile range rather than the simple standard +#' deviation. This reduces the effect of small numbers of extreme +#' outliers, and focuses on central variation. This then gets +#' converted to a standard error (standard deviation of the results, +#' in effect.) +#' +#' @param placebo_estimates The list of impact estimates to use for +#' estimating the SE. +#' @param k Number of obs to drop from each tail (so 2k obs dropped +#' total) +#' @param beta Proportion of obs to drop from each tail. E.g., 95\% to +#' drop 5\% most extreme. Only one of beta or k can be non-null. +#' @param round_beta TRUE means adjust beta to the nearest integer +#' number of units to drop (the Howard method). FALSE means +#' interpolate quantiles using R's quantile function. +#' +#' @noRd +estimate_robust_SE = function( placebo_estimates, k = NULL, beta = NULL, + round_beta = FALSE ) { + + stopifnot("Either `k` or `beta` must be NULL." 
= is.null(k) != is.null( beta ) ) + + n = length( placebo_estimates ) + + if ( is.null( beta ) ) { + #alpha = (2*k-1) / (2*n) # Method A + alpha = k/( n-1 ) # Method B + beta = 1 - 2*alpha + q = sort( placebo_estimates )[ c(1+k, n - k) ] + } else { + alpha = (1 - beta)/2 + k = alpha * (n-1) # QUESTION: Why n-1 here???? + if ( round_beta ) { + k = pmax( 1, round( k ) ) + alpha = k/(n-1) + beta = 1 - 2*alpha + q = sort( placebo_estimates )[ c(1+k, n - k) ] + } else { + q = quantile( placebo_estimates, probs = c( alpha, 1 - alpha ) ) + } + } + + del = as.numeric( diff( q ) ) + z = -qnorm( alpha ) + SE_imp = del / (2*z) + + res = data.frame( beta = beta, k = k, + q_low = q[[1]], q_high = q[[2]], + range = del, z = z, SE_imp = SE_imp ) + res +} + + +#' Construct an organized dataframe with outcome data in long format +#' from augsynth object +#' +#' @noRd +get_long_data <- function( augsynth ) { + + wide_data <- augsynth$data + + tx_id <- augsynth$trt_unit + control_ids = which( wide_data$trt == 0 ) + + all_ids = 1:nrow(wide_data$X) + all_ids[ control_ids ] = rownames( augsynth$weights ) + all_ids[ which( wide_data$trt == 1 ) ] = tx_id + + df <- bind_cols(wide_data$X, wide_data$y) %>% + mutate(!!as.name(augsynth$unit_var) := all_ids) %>% + pivot_longer(!augsynth$unit_var, + names_to = augsynth$time_var, + values_to = 'Yobs', + ) %>% + mutate(!!as.name(augsynth$time_var) := as.numeric(!!as.name(augsynth$time_var)), + ever_Tx = ifelse(!!as.name(augsynth$unit_var) == augsynth$trt_unit, 1, 0)) + + df +} + + + +add_placebo_distribution <- function(augsynth) { + + # Run permutations + ests <- get_placebo_gaps(augsynth, att = FALSE) + time_cols = 2:ncol(ests) + + ests$trt = augsynth$data$trt + + lest = ests %>% + pivot_longer( cols = all_of( time_cols ), + names_to = augsynth$time_var, values_to = "Yhat" ) %>% + mutate(!!as.name(augsynth$time_var) := as.numeric(!!as.name(augsynth$time_var))) + + df <- get_long_data( augsynth ) + + ##### Make dataset of donor units with their weights ##### + units = dplyr::select( df, !!as.name(augsynth$unit_var), ever_Tx ) %>% + unique() + + tx_col = augsynth$unit_var + weights = data.frame( tx_col = rownames( augsynth$weights ), + weights = augsynth$weights, + stringsAsFactors = FALSE) + colnames(weights)[1] <- augsynth$unit_var + + units = merge( units, weights, by=augsynth$unit_var, all=TRUE ) + + # Zero weights for tx units. + units$weights[ units$ever_Tx == 1 ] = 0 + + # confirm that we have placebos for every entity over every observed time period + stopifnot("Placebos do not cover every entity over every observed time period." 
= (ncol(augsynth$data$X) + ncol(augsynth$data$y)) * nrow(augsynth$data$y) == nrow(lest)) + + lest = rename( lest, !!as.name(augsynth$unit_var) := ID ) + + nn = nrow(lest) + lest = left_join( lest, + df[ c(augsynth$unit_var, augsynth$time_var, "Yobs") ], # issue is that this is calling for original data + by = names(lest)[names(lest) %in% names(df[ c(augsynth$unit_var, augsynth$time_var, "Yobs") ])] ) + stopifnot( nrow( lest ) == nn ) + + + # Impact is difference between observed and imputed control-side outcome + lest$impact = lest$Yobs - lest$Yhat + + #### Make the actual treatment result information ##### + + # Get our observed series (the treated unit) + T_tract = filter( lest, trt == 1 ) + + # The raw donor pool average series + averages = df %>% + filter( ever_Tx == 0 ) %>% + group_by( !!as.name(augsynth$time_var) ) %>% + summarise( raw_average = mean( .data$Yobs ), + .groups = "drop" ) + T_tract = left_join( T_tract, averages, by = augsynth$time_var ) + + ##### Calculate ATT and MDES ##### + + lest = rename( lest, ATT = impact ) + t0 <- ncol(augsynth$data$X) + treat_year = augsynth$data$time[t0 + 1] + + res = calculate_MDES_table(lest, treat_year, augsynth$unit_var, augsynth$time_var) + + # Add RMSPE to donor list + units = left_join( units, res$RMSPEs, by = names(units)[names(units) %in% names(res$RMSPEs)]) + + MDES_table = mutate( res$MDES_table, + raw_average = T_tract$raw_average, + tx = ifelse( !!as.name(augsynth$time_var) >= treat_year, 1, 0 ) ) %>% + relocate( raw_average, .after = Yhat ) + + augsynth$results$permutations <- list( placebo_dist = res$rstatcalc, + MDES_table = MDES_table) + + return(augsynth) +} + + +#' Generate formatted outputs for statistical inference using permutation inference +#' +#' @param augsynth An augsynth object. +#' +#' @noRd +permutation_inf <- function(augsynth, inf_type) { + + t0 <- dim(augsynth$data$synth_data$Z0)[1] + tpost <- dim(augsynth$data$synth_data$Z0)[1] + + out <- list() + out$att <- augsynth$results$permutations$MDES_table$ATT + + SEg = NA + if (inf_type == 'permutation') { + SEg = augsynth$results$permutations$MDES_table$SE_gap + pval = augsynth$results$permutations$MDES_table$p_gap + } else if (inf_type == 'permutation_rstat') { + SEg = augsynth$results$permutations$MDES_table$SE_rstat + pval = augsynth$results$permutations$MDES_table$p_rstat + } + out$lb <- out$att + (qnorm(0.025) * SEg) + out$ub <- out$att + (qnorm(0.975) * SEg) + out$p_val <- pval + + out$lb[c(1:t0)] <- NA + out$ub[c(1:t0)] <- NA + out$p_val[c(1:t0)] <- NA + out$alpha <- 0.05 + + return(out) +} + + + + + +#' RMSPE for treated unit +#' +#' @param augsynth Augsynth object +#' @return RMSPE (Root mean squared predictive error) for the treated unit in pre-treatment era +#' +#' @export +RMSPE <- function( augsynth ) { + stopifnot( is.augsynth(augsynth) ) + + pd = predict( augsynth, att = TRUE ) + sqrt( mean( pd[1:ncol(augsynth$data$X)]^2 ) ) +} + + + + + +#' Get placebo distribution +#' +#' @param augsynth Augsynth object or summery object with permutation inference of some sort. +#' +#' @return Data frame holding the placebo distribution, one row per placebo unit and time point. 
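+#'
+#' @examples
+#' \dontrun{
+#' # Illustrative sketch: `syn` is assumed to be a fitted single_augsynth
+#' # model. If permutation inference has not been run yet, it is run here.
+#' placebos <- placebo_distribution(syn)
+#' head(placebos)
+#' }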
+#' +#' @export +placebo_distribution <- function( augsynth ) { + inf_type = NA + if ( is_summary_augsynth(augsynth) ) { + inf_type = augsynth$inf_type + } else if ( is.augsynth(augsynth) ) { + augsynth <- summary( augsynth, inf_type = "permutation" ) + inf_type = augsynth$inf_type + } else { + stop( "Object must be an Augsynth object or summary object" ) + } + + if ( !is.null( inf_type ) && inf_type %in% c( "permutation", "permutation_rstat" ) ) { + return( augsynth$permutations$placebo_dist ) + } else { + stop( "Placebo distribution only available for permutation inference" ) + } +} + + + + + + + +#' Return a summary data frame for the treated unit +#' +#' @param augsynth Augsynth object of interest +#' +#' @return Dataframe of information about the treated unit, one row +#' per time point. This includes the measured outcome, predicted +#' outcome from the synthetic unit, the average of all donor units +#' (as reference, called `raw_average`), and the estimated impact +#' (`ATT`), and the r-statistic (ATT divided by RMSPE). +#' +#' @seealso [donor_table()] +#' @export +treated_table <- function(augsynth) { + + if ( is_summary_augsynth( augsynth ) ) { + return( augsynth$treated_table ) + } + + # Calculate the time series of the treated, the synthetic control, + # and the overall donor pool average + trt_index <- which(augsynth$data$trt == 1) + df <- bind_cols(augsynth$data$X, augsynth$data$y) + # synth_unit <- t(df[-trt_index, ]) %*% augsynth$weights + synth_unit <- predict(augsynth) + average_unit <- df[-trt_index, ] %>% colMeans() + treated_unit <- t(df[trt_index, ]) + lvls = tibble( + time = as.numeric( colnames(df) ), + Yobs = as.numeric( treated_unit ), + Yhat = as.numeric( synth_unit ), + raw_average = as.numeric( average_unit ) + ) + + #lvls <- df %>% + # group_by( !!sym(augsynth$time_var ), ever_Tx) %>% + # summarise( Yavg = mean( Yobs ), .groups="drop" ) %>% + # pivot_wider( names_from = ever_Tx, values_from = Yavg ) + #colnames(lvls)[2:3] <- c("raw_average", "Yobs") + + t0 <- ncol(augsynth$data$X) + tpost <- ncol(augsynth$data$y) + lvls$tx = rep( c(0,1), c( t0, tpost ) ) + #lvls$Yhat = predict( augsynth ) + lvls$ATT = lvls$Yobs - lvls$Yhat + lvls$rstat = lvls$ATT / sqrt( mean( lvls$ATT[ lvls$tx == 0 ]^2 ) ) + + lvls <- dplyr::relocate( lvls, + time, tx, Yobs, Yhat, raw_average, ATT, rstat ) + + return( lvls ) +} + + + + diff --git a/README.md b/README.md index f7a3736..f992fb4 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,11 @@ ## Overview This package implements the Augmented Synthetic Control Method (ASCM). +In particular, there are three types of ASCM implemented in this package: + +1. **single_augsynth**: The ASCM version of the classic synthetic control approach with a single treated unit. +2. **multisynth**: ASCM that estimates the treatment effect for multiple treated units with staggered adoption. +3. **augsynth_multiout**: ASCM for a single treated unit with multiple outcomes. 
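+
+As a quick, illustrative sketch of the first of these (the data frame and
+column names below are placeholders, not package data):
+
+```r
+library(augsynth)
+
+# df: long-format panel with one row per unit-year; trt is 1 for the treated
+# unit in post-treatment years and 0 otherwise (hypothetical names)
+syn <- augsynth(outcome ~ trt, unit = unit, time = year,
+                data = df, t_int = 2005,
+                progfunc = "ridge", scm = TRUE)
+
+summary(syn)  # ATT estimates with inference
+plot(syn)     # effect plot
+```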
For a more detailed description of the main functionality check out: - [the vignette for simultaneous adoption](https://github.com/ebenmichael/augsynth/blob/master/vignettes/singlesynth-vignette.md) diff --git a/man/RMSPE.Rd b/man/RMSPE.Rd new file mode 100644 index 0000000..2e6eaaa --- /dev/null +++ b/man/RMSPE.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/permutation.R +\name{RMSPE} +\alias{RMSPE} +\title{RMSPE for treated unit} +\usage{ +RMSPE(augsynth) +} +\arguments{ +\item{augsynth}{Augsynth object} +} +\value{ +RMSPE (Root mean squared predictive error) for the treated unit in pre-treatment era +} +\description{ +RMSPE for treated unit +} diff --git a/man/add_inference.Rd b/man/add_inference.Rd new file mode 100644 index 0000000..8345c4e --- /dev/null +++ b/man/add_inference.Rd @@ -0,0 +1,33 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/augsynth.R +\name{add_inference} +\alias{add_inference} +\title{Function to add inference to augsynth object} +\usage{ +add_inference(object, inf_type = "conformal", linear_effect = F, ...) +} +\arguments{ +\item{object}{augsynth object} + +\item{inf_type}{Type of inference algorithm. Options are +\itemize{ + \item{"conformal"}{Conformal inference (default)} + \item{"jackknife+"}{Jackknife+ algorithm over time periods} + \item{"jackknife"}{Jackknife over units} + \item{"permutation"}{Placebo permutation, raw ATT} + \item{"permutation_rstat"}{Placebo permutation, RMSPE adjusted ATT} +}} + +\item{linear_effect}{Boolean, whether to invert the conformal inference hypothesis test to get confidence intervals for a linear-in-time treatment effect: intercept + slope * time} + +\item{...}{Optional arguments for inference, for more details for each `inf_type` see +\itemize{ + \item{"conformal"}{`conformal_inf`} + \item{"jackknife+"}{`time_jackknife_plus`} + \item{"jackknife"}{`jackknife_se_single`} + \item{"permutation"}{`permutation_inf`} +}} +} +\description{ +Function to add inference to augsynth object +} diff --git a/man/augsynth-package.Rd b/man/augsynth-package.Rd index 737cc24..c6fdcfe 100644 --- a/man/augsynth-package.Rd +++ b/man/augsynth-package.Rd @@ -7,3 +7,7 @@ \description{ A package implementing the Augmented Synthetic Controls Method } +\author{ +\strong{Maintainer}: Eli Ben-Michael \email{ebenmichael@berkeley.edu} + +} diff --git a/man/augsynth.Rd b/man/augsynth.Rd index cd5d46a..722fa2a 100644 --- a/man/augsynth.Rd +++ b/man/augsynth.Rd @@ -15,7 +15,8 @@ augsynth(form, unit, time, data, t_int = NULL, ...) \item{data}{Panel data as dataframe} -\item{t_int}{Time of intervention (used for single-period treatment only)} +\item{t_int}{Time of intervention (used for single-period treatment +only)} \item{...}{Optional arguments \itemize{ @@ -23,7 +24,7 @@ augsynth(form, unit, time, data, t_int = NULL, ...) 
\itemize{ \item{"progfunc"}{What function to use to impute control outcomes: Ridge=Ridge regression (allows for standard errors), None=No outcome model, EN=Elastic Net, RF=Random Forest, GSYN=gSynth, MCP=MCPanel, CITS=CITS, CausalImpact=Bayesian structural time series with CausalImpact, seq2seq=Sequence to sequence learning with feedforward nets} \item{"scm"}{Whether the SCM weighting function is used} - \item{"fixedeff"}{Whether to include a unit fixed effect, default F } + \item{"fixedeff"}{Whether to include a unit fixed effect, default is FALSE } \item{"cov_agg"}{Covariate aggregation functions, if NULL then use mean with NAs omitted} } \item Multi period (staggered) augsynth @@ -39,12 +40,18 @@ augsynth(form, unit, time, data, t_int = NULL, ...) }} } \value{ -augsynth object that contains: +augsynth or multisynth object (depending on dispatch) that contains (among other things): \itemize{ \item{"weights"}{weights} \item{"data"}{Panel data as matrices} } } \description{ -Fit Augmented SCM +The general `augsynth()` method dispatches to `single_augsynth`, +`augsynth_multiout`, or `multisynth` based on the number of +outcomes and treatment times. See documentation for these methods +for further detail. +} +\seealso{ +`single_augsynth`, `augsynth_multiout`, `multisynth` } diff --git a/man/augsynth_class.Rd b/man/augsynth_class.Rd new file mode 100644 index 0000000..66af8ef --- /dev/null +++ b/man/augsynth_class.Rd @@ -0,0 +1,39 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/augsynth.R +\name{is.augsynth} +\alias{is.augsynth} +\alias{dim.augsynth} +\alias{n_unit.augsynth} +\alias{n_time.augsynth} +\alias{n_treated.augsynth} +\title{Methods for accessing details of augsynth result object (of class augsynth)} +\usage{ +is.augsynth(x) + +\method{dim}{augsynth}(x, ...) + +\method{n_unit}{augsynth}(x, ...) + +\method{n_time}{augsynth}(x, ...) + +\method{n_treated}{augsynth}(x, ...) +} +\arguments{ +\item{x}{augsynth result object} +} +\value{ +is.augsynth: TRUE if object is a augsynth object. + +dim: Dimension of data as pair of (# units, # time points). + +Single number (of unique units). + +Single number (of unique time points). + +Number of treated units (always 1 for augsynth) +} +\description{ +Methods for accessing details of augsynth result object (of class augsynth) + +Number of time points in augsynth +} diff --git a/man/bootstrap_methods.Rd b/man/bootstrap_methods.Rd new file mode 100644 index 0000000..d6946ae --- /dev/null +++ b/man/bootstrap_methods.Rd @@ -0,0 +1,32 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/inference.R +\name{bootstrap_methods} +\alias{bootstrap_methods} +\alias{rdirichlet_b} +\alias{rmultinom_b} +\alias{rwild_b} +\title{Bootstrap Functions} +\usage{ +rdirichlet_b(n) + +rmultinom_b(n) + +rwild_b(n) +} +\arguments{ +\item{n}{Number of units} +} +\description{ +There are several helper functions used for bootstrap inference. +Each method returns a list of selection weights of length n, +which weights that sum to n. 
+} +\section{Functions}{ +\itemize{ +\item \code{rdirichlet_b()}: Bayesian bootstrap + +\item \code{rmultinom_b()}: Non-parametric bootstrap + +\item \code{rwild_b()}: Wild bootstrap (Mammen 1993) + +}} diff --git a/man/covariate_balance_table.Rd b/man/covariate_balance_table.Rd new file mode 100644 index 0000000..ce4b531 --- /dev/null +++ b/man/covariate_balance_table.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/calc_covariate_balance.R +\name{covariate_balance_table} +\alias{covariate_balance_table} +\title{Make covariate balance table.} +\usage{ +covariate_balance_table(ascm, pre_period = NULL) +} +\arguments{ +\item{ascm}{An augsynth result object from single_augsynth} + +\item{pre_period}{List of names of the pre-period timepoints to +calculate balance for. NULL means none.} +} +\description{ +Make a table comparing means of covariates in the treated group, +the raw control group, and the new weighted control group (the +synthetic control) +} diff --git a/man/donor_table.Rd b/man/donor_table.Rd new file mode 100644 index 0000000..3fdfa47 --- /dev/null +++ b/man/donor_table.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/donor_control.R +\name{donor_table} +\alias{donor_table} +\title{Return a summary data frame donor units used in the model with +their synthetic weights.} +\usage{ +donor_table(augsynth, include_RMSPE = TRUE, zap_weights = 1e-07) +} +\arguments{ +\item{augsynth}{Augsynth object to be plotted} + +\item{include_RMSPE}{Include RMSPEs in the table even if +permutation inference has not yet been conducted.} + +\item{zap_weights}{all weights smaller than this value will be set +to zero. Set to NULL to keep all weights.} +} +\description{ +If permutation inference has been conducted, table will also +include RMSPEs. This can be forced with include_RMSPE flag. +} +\details{ +If the augsynth object does not have permutation-based inference +results, the function will call that form of inference, in order to +calculate the RMSPEs for each donor unit in turn. +} diff --git a/man/make_V_matrix.Rd b/man/make_V_matrix.Rd deleted file mode 100644 index 30934d6..0000000 --- a/man/make_V_matrix.Rd +++ /dev/null @@ -1,11 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/fit_synth.R -\name{make_V_matrix} -\alias{make_V_matrix} -\title{Make a V matrix from a vector (or null)} -\usage{ -make_V_matrix(t0, V) -} -\description{ -Make a V matrix from a vector (or null) -} diff --git a/man/multisynth_class.Rd b/man/multisynth_class.Rd new file mode 100644 index 0000000..09719c5 --- /dev/null +++ b/man/multisynth_class.Rd @@ -0,0 +1,31 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/multisynth_class.R +\name{dim.multisynth} +\alias{dim.multisynth} +\alias{n_unit.multisynth} +\alias{n_time.multisynth} +\alias{n_treated.multisynth} +\title{Number of units in multisynth} +\usage{ +\method{dim}{multisynth}(x, ...) + +\method{n_unit}{multisynth}(x, ...) + +\method{n_time}{multisynth}(x, ...) + +\method{n_treated}{multisynth}(x, ...) +} +\value{ +dim: Dimension of data as pair of (# units, # time points). + +Single number (of unique units). + +Single number (of unique time points). 
+ +Number of treated units +} +\description{ +Number of units in multisynth + +Number of time points in multisynth +} diff --git a/man/permutation_plot.Rd b/man/permutation_plot.Rd new file mode 100644 index 0000000..e3ec55e --- /dev/null +++ b/man/permutation_plot.Rd @@ -0,0 +1,26 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/augsynth.R +\name{permutation_plot} +\alias{permutation_plot} +\title{Generate permutation plots} +\usage{ +permutation_plot(augsynth, inf_type = "permutation") +} +\arguments{ +\item{inf_type}{Type of permutation inference to plot. +Options are + \itemize{ + \item{"permutation"}{Placebo permutation, raw ATT} + \item{"permutation_rstat"}{Placebo permutation, RMSPE adjusted ATT} + }} + +\item{augsynth}{Augsynth result object from calling augsynth()} +} +\description{ +Generate permutation plots +} diff --git a/man/placebo_distribution.Rd b/man/placebo_distribution.Rd new file mode 100644 index 0000000..ebcab05 --- /dev/null +++ b/man/placebo_distribution.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/permutation.R +\name{placebo_distribution} +\alias{placebo_distribution} +\title{Get placebo distribution} +\usage{ +placebo_distribution(augsynth) +} +\arguments{ +\item{augsynth}{Augsynth object or summary object with permutation inference of some sort.} +} +\value{ +Data frame holding the placebo distribution, one row per placebo unit and time point. +} +\description{ +Get placebo distribution +} diff --git a/man/plot.augsynth.Rd b/man/plot.augsynth.Rd index 0c8f052..6da60f6 100644 --- a/man/plot.augsynth.Rd +++ b/man/plot.augsynth.Rd @@ -4,15 +4,31 @@ \alias{plot.augsynth} \title{Plot function for augsynth} \usage{ -\method{plot}{augsynth}(x, inf = T, cv = F, ...) +\method{plot}{augsynth}(augsynth, cv = FALSE, plot_type = "estimate", inf_type = NULL, ...) } \arguments{ -\item{x}{Augsynth object to be plotted} - -\item{inf}{Boolean, whether to get confidence intervals around the point estimates} +\item{augsynth}{Augsynth or summary.augsynth object to be plotted} \item{cv}{If True, plot cross validation MSE against hyper-parameter, otherwise plot effects} +\item{plot_type}{The stylized plot type to be returned. Options include +\itemize{ + \item{"estimate"}{The ATT and 95\% confidence interval} + \item{"estimate only"}{The ATT without a confidence interval} + \item{"outcomes"}{The level of the outcome variable for the treated and synthetic control units.} + \item{"outcomes raw average"}{The level of the outcome variable for the treated and synthetic control units, along with the raw average of the donor units.} + \item{"placebo"}{The ATTs resulting from placebo tests on the donor units.} }} + +\item{inf_type}{Type of inference algorithm. Inherits inf_type from `object` or otherwise defaults to "conformal".
Options are +\itemize{ + \item{"conformal"}{Conformal inference (default)} + \item{"jackknife+"}{Jackknife+ algorithm over time periods} + \item{"jackknife"}{Jackknife over units} + \item{"permutation"}{Placebo permutation, raw ATT} + \item{"permutation_rstat"}{Placebo permutation, RMSPE adjusted ATT} + \item{"None"}{Return ATT Estimate only} +}} + \item{...}{Optional arguments} } \description{ diff --git a/man/plot.summary.augsynth.Rd b/man/plot.summary.augsynth.Rd index 5d04bab..a811d9e 100644 --- a/man/plot.summary.augsynth.Rd +++ b/man/plot.summary.augsynth.Rd @@ -2,17 +2,29 @@ % Please edit documentation in R/augsynth.R \name{plot.summary.augsynth} \alias{plot.summary.augsynth} -\title{Plot function for summary function for augsynth} +\title{Plot results from summarized augsynth} \usage{ -\method{plot}{summary.augsynth}(x, inf = T, ...) +\method{plot}{summary.augsynth}(x, plot_type = "estimate", ...) } \arguments{ -\item{x}{Summary object} +\item{x}{summary.augsynth object} -\item{inf}{Boolean, whether to plot confidence intervals} +\item{plot_type}{The stylized plot type to be returned. Options include +\itemize{ + \item{"cv"}{Cross-validation diagnostic plot} + \item{"estimate"}{The ATT and 95\% confidence interval} + \item{"estimate only"}{The ATT without a confidence interval} + \item{"outcomes"}{The level of the outcome variable for the treated and synthetic control units.} + \item{"outcomes raw average"}{The level of the outcome variable for the treated and synthetic control units, along with the raw average of the donor units.} + \item{"placebo"}{The ATTs resulting from placebo tests on the donor units.} }} \item{...}{Optional arguments} } \description{ -Plot function for summary function for augsynth +Make a variety of plots, depending on plot_type. The default is to +plot impacts with associated uncertainty (if present). Other +options are estimates with no uncertainty ("estimate only"), the +level of the outcome ("outcomes"), the level of the outcome with the raw donor average +as well ("outcomes raw average"), and the classic spaghetti plot +("placebo"). } diff --git a/man/plot_augsynth_results.Rd b/man/plot_augsynth_results.Rd new file mode 100644 index 0000000..b118ec8 --- /dev/null +++ b/man/plot_augsynth_results.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/augsynth.R +\name{plot_augsynth_results} +\alias{plot_augsynth_results} +\title{Plot function for augsynth or summary.augsynth objects} +\usage{ +plot_augsynth_results(augsynth, plot_type = "estimate", inf_type = NULL, ...) +} +\arguments{ +\item{augsynth}{augsynth or summary.augsynth object to be plotted} + +\item{plot_type}{The stylized plot type to be returned. Options include +\itemize{ + \item{"cv"}{Cross-validation diagnostic plot} + \item{"estimate"}{The ATT and 95\% confidence interval} + \item{"estimate only"}{The ATT without a confidence interval} + \item{"outcomes"}{The level of the outcome variable for the treated and synthetic control units.} + \item{"outcomes raw average"}{The level of the outcome variable for the treated and synthetic control units, along with the raw average of the donor units.} + \item{"placebo"}{The ATTs resulting from placebo tests on the donor units.} }} + +\item{inf_type}{Type of inference algorithm. Inherits inf_type from `object` or otherwise defaults to "conformal".
Options are +\itemize{ + \item{"conformal"}{Conformal inference (default)} + \item{"jackknife+"}{Jackknife+ algorithm over time periods} + \item{"jackknife"}{Jackknife over units} + \item{"permutation"}{Placebo permutation, raw ATT} + \item{"permutation_rstat"}{Placebo permutation, RMSPE adjusted ATT} + \item{"None"}{Return ATT Estimate only} +}} + +\item{...}{Optional arguments} + +\item{cv}{If True, plot cross validation MSE against hyper-parameter, otherwise plot effects} +} +\description{ +Plot function for augsynth or summary.augsynth objects +} diff --git a/man/rdirichlet_b.Rd b/man/rdirichlet_b.Rd deleted file mode 100644 index 1a498da..0000000 --- a/man/rdirichlet_b.Rd +++ /dev/null @@ -1,14 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/inference.R -\name{rdirichlet_b} -\alias{rdirichlet_b} -\title{Bayesian bootstrap} -\usage{ -rdirichlet_b(n) -} -\arguments{ -\item{n}{Number of units} -} -\description{ -Bayesian bootstrap -} diff --git a/man/rmultinom_b.Rd b/man/rmultinom_b.Rd deleted file mode 100644 index d6bf719..0000000 --- a/man/rmultinom_b.Rd +++ /dev/null @@ -1,14 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/inference.R -\name{rmultinom_b} -\alias{rmultinom_b} -\title{Non-parametric bootstrap} -\usage{ -rmultinom_b(n) -} -\arguments{ -\item{n}{Number of units} -} -\description{ -Non-parametric bootstrap -} diff --git a/man/rwild_b.Rd b/man/rwild_b.Rd deleted file mode 100644 index 6503f78..0000000 --- a/man/rwild_b.Rd +++ /dev/null @@ -1,14 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/inference.R -\name{rwild_b} -\alias{rwild_b} -\title{Wild bootstrap (Mammen 1993)} -\usage{ -rwild_b(n) -} -\arguments{ -\item{n}{Number of units} -} -\description{ -Wild bootstrap (Mammen 1993) -} diff --git a/man/single_augsynth.Rd b/man/single_augsynth.Rd index 85654d1..4c6b7bf 100644 --- a/man/single_augsynth.Rd +++ b/man/single_augsynth.Rd @@ -32,17 +32,19 @@ single_augsynth( ridge=Ridge regression (allows for standard errors), none=No outcome model, en=Elastic Net, RF=Random Forest, GSYN=gSynth, -mcp=MCPanel, +mcp=MCPanel, cits=Comparitive Interuppted Time Series causalimpact=Bayesian structural time series with CausalImpact} -\item{scm}{Whether the SCM weighting function is used} +\item{scm}{Whether the SCM weighting function is used. If FALSE, then package will fit the outcome model, but not calculate new donor weights to match pre-treatment covariates. Instead, each donor unit will be equally weighted. If TRUE, weights on donor pool will be calculated.} \item{fixedeff}{Whether to include a unit fixed effect, default F} \item{cov_agg}{Covariate aggregation functions, if NULL then use mean with NAs omitted} \item{...}{optional arguments for outcome model} + +\item{plot}{Whether or not to return a plot of the augsynth model} } \value{ augsynth object that contains: diff --git a/man/summary.augsynth.Rd b/man/summary.augsynth.Rd index 8d1c7cd..9e0f62f 100644 --- a/man/summary.augsynth.Rd +++ b/man/summary.augsynth.Rd @@ -4,29 +4,35 @@ \alias{summary.augsynth} \title{Summary function for augsynth} \usage{ -\method{summary}{augsynth}(object, inf = T, inf_type = "conformal", linear_effect = F, ...) +\method{summary}{augsynth}(object, inf_type = "conformal", ...) } \arguments{ \item{object}{augsynth object} -\item{inf}{Boolean, whether to get confidence intervals around the point estimates} +\item{inf_type}{Type of inference algorithm. 
If left NULL, inherits +inf_type from `object` or otherwise defaults to "conformal". +Options are + \itemize{ + \item{"conformal"}{Conformal inference (default)} + \item{"jackknife+"}{Jackknife+ algorithm over time periods} + \item{"jackknife"}{Jackknife over units} + \item{"permutation"}{Placebo permutation, raw ATT} + \item{"permutation_rstat"}{Placebo permutation, RMSPE adjusted ATT} + }} -\item{inf_type}{Type of inference algorithm. Options are -\itemize{ - \item{"conformal"}{Conformal inference (default)} - \item{"jackknife+"}{Jackknife+ algorithm over time periods} - \item{"jackknife"}{Jackknife over units} -}} - -\item{linear_effect}{Boolean, whether to invert the conformal inference hypothesis test to get confidence intervals for a linear-in-time treatment effect: intercept + slope * time} - -\item{...}{Optional arguments for inference, for more details for each `inf_type` see -\itemize{ - \item{"conformal"}{`conformal_inf`} - \item{"jackknife+"}{`time_jackknife_plus`} - \item{"jackknife"}{`jackknife_se_single`} -}} +\item{...}{Optional arguments for inference, for more details for +each `inf_type` see + \itemize{ + \item{"conformal"}{`conformal_inf`} + \item{"jackknife+"}{`time_jackknife_plus`} + \item{"jackknife"}{`jackknife_se_single`} + \item{"permutation", "permutation_rstat"}{`permutation_inf`} + }} } \description{ -Summary function for augsynth +The summary method summarizes an augsynth result by (usually) adding an +inferential result, if one has not been calculated already, and by +calculating a few other summary statistics such as the estimated bias. +Inference, when needed, is added via `add_inference()`. } diff --git a/man/summary.augsynth_class.Rd b/man/summary.augsynth_class.Rd new file mode 100644 index 0000000..14eec75 --- /dev/null +++ b/man/summary.augsynth_class.Rd @@ -0,0 +1,17 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/augsynth.R +\name{is_summary_augsynth} +\alias{is_summary_augsynth} +\title{Methods for accessing details of summary.augsynth object} +\usage{ +is_summary_augsynth(x) +} +\arguments{ +\item{x}{summary.augsynth result object} +} +\value{ +is_summary_augsynth: TRUE if object is a summary.augsynth object. +} +\description{ +Methods for accessing details of summary.augsynth object +} diff --git a/man/synth_class.Rd b/man/synth_class.Rd new file mode 100644 index 0000000..feab539 --- /dev/null +++ b/man/synth_class.Rd @@ -0,0 +1,26 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/augsynth.R +\name{n_unit} +\alias{n_unit} +\alias{n_time} +\alias{n_treated} +\title{Number of time points in fit data} +\usage{ +n_unit(x, ...) + +n_time(x, ...) + +n_treated(x, ...) +} +\value{ +Single number (of unique units). + +Single number (of unique time points). + +Single number (of number of treated units). +} +\description{ +Number of time points in fit data + +Number of treated units in fit data +} diff --git a/man/treated_table.Rd b/man/treated_table.Rd new file mode 100644 index 0000000..95d2b0e --- /dev/null +++ b/man/treated_table.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/permutation.R +\name{treated_table} +\alias{treated_table} +\title{Return a summary data frame for the treated unit} +\usage{ +treated_table(augsynth) +} +\arguments{ +\item{augsynth}{Augsynth object of interest} +} +\value{ +Data frame of information about the treated unit, one row + per time point.
This includes the measured outcome, predicted + outcome from the synthetic unit, the average of all donor units + (as reference, called `raw_average`), the estimated impact + (`ATT`), and the r-statistic (ATT divided by RMSPE). +} +\description{ +Return a summary data frame for the treated unit +} +\seealso{ +[donor_table()] +} diff --git a/man/update_augsynth.Rd b/man/update_augsynth.Rd new file mode 100644 index 0000000..5d1b3ed --- /dev/null +++ b/man/update_augsynth.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/donor_control.R +\name{update_augsynth} +\alias{update_augsynth} +\title{Return a new augsynth object with specified donor units removed} +\usage{ +update_augsynth(augsynth, drop = 20) +} +\arguments{ +\item{augsynth}{Augsynth object to be updated} + +\item{drop}{Drop donor units, based on pre-treatment RMSPE or unit +name(s). Default of 20 means drop units with an RMSPE 20x higher +than the treated unit. The `drop` parameter can also be a character vector of unit +IDs to drop.} +} +\description{ +Return a new augsynth object with specified donor units removed +} diff --git a/pkg.Rproj b/pkg.Rproj index d848a9f..bfa3107 100644 --- a/pkg.Rproj +++ b/pkg.Rproj @@ -5,8 +5,13 @@ SaveWorkspace: No AlwaysSaveHistory: Default EnableCodeIndexing: Yes +UseSpacesForTab: Yes +NumSpacesForTab: 4 Encoding: UTF-8 +RnwWeave: Sweave +LaTeX: pdfLaTeX + AutoAppendNewline: Yes StripTrailingWhitespace: Yes diff --git a/tests/testthat/test-augsynth_plotting.R b/tests/testthat/test-augsynth_plotting.R new file mode 100644 index 0000000..a9f0bef --- /dev/null +++ b/tests/testthat/test-augsynth_plotting.R @@ -0,0 +1,86 @@ +library( tidyverse ) +data(basque, package = "Synth") + +context("Test plotting features of single augsynth objects") + +basque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0, + regionno != 17 ~ 0, + regionno == 17 ~ 1)) %>% + filter(regionno != 1) + + +syn <- single_augsynth(gdpcap ~ trt, regionno, year, 1975, basque, + progfunc = "None", scm = T, fixedeff = F) +sum <- summary( syn, inf_type = "permutation" ) + + +test_that("Confirm equivalence of plotting estimate across augsynth and summary objects", { + + # Testing some of the permutation plotting functions + tst_plt <- augsynth:::permutation_plot( syn, inf_type = "permutation_rstat" ) + expect_true( is_ggplot( tst_plt ) ) + + + p0 <- plot( sum ) + p <- plot( sum, plot_type = "estimate" ) + expect_equal( p0, p ) + + pEO <- plot( sum, plot_type = "estimate only" ) + expect_message( + pEO1 <- plot( syn, plot_type = "estimate only", inf_type = "permutation" ) + ) + expect_equal( pEO, pEO1 ) +}) + + + +test_that("Confirm equivalence of plotting outcomes across augsynth and summary objects", { + + po <- plot( sum, plot_type = "outcomes" ) + po1 = plot( syn, plot_type = "outcomes", inf_type = "permutation" ) + expect_equal( po, po1 ) + + praw <- plot( sum, plot_type = "outcomes raw average" ) + praw1 <- plot( syn, plot_type = "outcomes raw average", inf_type="permutation" ) + expect_equal( praw$data, praw1$data ) + + # These two plots should be the same + p2 <- plot( syn ) + expect_true( is_ggplot(p2 ) ) + +}) + + +test_that("Requesting a placebo plot defaults to permutation inference", { + + syn <- single_augsynth(gdpcap ~ trt, regionno, year, 1975, basque, + progfunc = "None", scm = T, fixedeff = F) + + sum <- summary( syn, inf_type = "permutation" ) + + ppla <- plot( sum, plot_type = "placebo" ) + ppla1 <- plot( syn, plot_type = "placebo", inf_type="permutation" ) + 
expect_equal( ppla, ppla1 ) + expect_true(all(ppla$data$ATT == ppla1$data$ATT)) +}) + + +test_that("Check that plotting functions will add inference on the fly", { + + syn <- single_augsynth(gdpcap ~ trt, regionno, year, 1975, basque, + progfunc = "None", scm = T, fixedeff = F) + + expect_message( auto_add_inf <- plot( syn, inf_type = "permutation_rstat" ) ) + expect_true( is_ggplot( auto_add_inf ) ) +}) + + +test_that("Check that we can use plotting helper functions from summary", { + + sum <- summary(syn) + + plt <- augsynth:::augsynth_outcomes_plot( sum ) + expect_true( is_ggplot( plt ) ) + +}) diff --git a/tests/testthat/test-donor_control.R b/tests/testthat/test-donor_control.R new file mode 100644 index 0000000..0432cdc --- /dev/null +++ b/tests/testthat/test-donor_control.R @@ -0,0 +1,116 @@ + + +set.seed(7393) +dat = augsynth:::make_synth_data( n_time = 10, n_U = 5, N = 12, long_form = TRUE, tx_impact = 2, tx_shift = 1 ) +# ggplot( dat, aes( time, Y, group=ID ) ) + geom_line() + +syn = augsynth( Y ~ Tx | X1, unit = ID, time = time, data = dat, fixedeff = TRUE, progfunc = "none" ) + +test_that("Donor control - high RMSPE multiple doesn't drop any units", { + syn2 <- update_augsynth(syn, drop = 1000) + expect_equal(nrow(donor_table(syn2)), 11) +}) + + + +test_that("Donor control drops the correct units based on RMSPE", { + + ssyn = summary( syn, inf_type="permutation" ) + if ( FALSE ) { + plot( ssyn, "outcomes raw average" ) + } + treated_table(syn) + ssyn + + dtable <- donor_table(syn) + dtable + nrow( dtable ) + trt_RMSPE <- add_inference(syn, inf_type = 'permutation')$results$permutations$placebo_dist %>% + filter(time < syn$t_int) %>% + filter(ID == 1) %>% + pull(RMSPE) %>% unique() + + drop_factor = 1.0 + drop_units <- dtable %>% + filter(RMSPE > trt_RMSPE * drop_factor) %>% + pull(ID) + drop_units + + syn3 <- update_augsynth(syn, drop = drop_factor) + d3 <- donor_table(syn3) + expect_true( all( !( drop_units %in% d3$ID ) ) ) + + # TODO: Why should the new donor table have 8 rows? I dropped + # that check in the following: + + #criteria <- (nrow(donor_table(syn3)) == 8) & (drop_units %in% $ID %>% all() == FALSE) + #expect_true(criteria) + + + # Check that we get the same model parameters from dropping units manually + dat_tmp <- dat %>% filter(!ID %in% drop_units) + syn_manual <- augsynth( Y ~ Tx | X1, unit = ID, time = time, data = dat_tmp, + fixedeff = TRUE, progfunc = "none" ) + s_syn_manual <- summary(syn_manual) + s_syn3 <- summary(syn3) + + expect_equal( s_syn_manual$att$Estimate, s_syn3$att$Estimate ) + + expect_equal( nrow(donor_table(syn_manual)), nrow(donor_table(syn3))) + + + # Check that dropping useless donors doesn't change anything + dd = setdiff( as.character( 1:12 ), dtable$ID ) + synX = update_augsynth( syn, drop = dd ) + expect_equal( syn$att$Estimate, synX$att$Estimate ) + expect_equal( nrow(donor_table(syn)), nrow(donor_table(synX))) + +}) + + + +test_that("Donor control drops the correct units based on unit names", { + drop_units <- c('2', '3', '4') + + syn4 <- update_augsynth(syn, drop = drop_units) + + expect_true( all( ! 
(drop_units %in% donor_table(syn4)$ID ) ) ) + + + # Check that we get the same model parameters from dropping units manually + dat_tmp <- dat %>% filter(!ID %in% drop_units) + syn_manual <- augsynth( Y ~ Tx | X1, unit = ID, time = time, + data = dat_tmp, fixedeff = TRUE, + progfunc = "none" ) + s_syn_manual <- summary(syn_manual) + s_syn4 <- summary(syn4) + + expect_equal( s_syn_manual$att$Estimate, s_syn4$att$Estimate ) + expect_equal( nrow(donor_table(syn_manual)), nrow(donor_table(syn4))) + + +}) + + +test_that("`update_augsynth` returns the same basic SCM in cases when updates are not applied", { + + syn1 <- augsynth(lngdpcapita ~ treated, fips, year_qtr, kansas, + progfunc = "None", scm = T) + syn2 <- update_augsynth(syn1, Inf) # dropping based on RMSPE multiple + syn3 <- update_augsynth(syn1, "") # dropping based on unit names + + # test that weights are the same + expect_equal(syn2$weights, syn3$weights) + expect_equal(syn1$weights, syn2$weights) + + # test that ATTs are the same + summ1 <- summary(syn1) + summ2 <- summary(syn2) + summ3 <- summary(syn3) + expect_equal(summ2$att['Estimate'], summ3$att['Estimate']) + expect_equal(summ1$att['Estimate'], summ2$att['Estimate']) +}) + + + + diff --git a/tests/testthat/test-permutation_inference.R b/tests/testthat/test-permutation_inference.R new file mode 100644 index 0000000..645c406 --- /dev/null +++ b/tests/testthat/test-permutation_inference.R @@ -0,0 +1,94 @@ + + +library( tidyverse ) + +data(basque, package="Synth") + +head( basque) + +#basque <- basque %>% +# dplyr::select( regionno, year, gdpcap ) %>% +# as_tibble() + +basque <- basque %>% + mutate(trt = case_when(year < 1975 ~ 0, + regionno != 17 ~ 0, + regionno == 17 ~ 1)) %>% + filter(regionno != 1) %>% + dplyr::select( -c( sec.agriculture:invest ) ) + +head( basque ) + +table( basque$trt, basque$regionno ) + + + +test_that( "MDES_table corresponds to default treatment table", { + + syn <- augsynth(gdpcap ~ trt, regionno, year, + data=basque, scm = TRUE) + + summ <- summary(syn, inf_type = 'permutation') + + mm <- summ$permutations$MDES_table[1:7] %>% + select( sort( names(.))) + mm + tt <- treated_table(syn) %>% + select( sort( names(.))) + tt + + expect_equal( as.data.frame(tt)[c("ATT","raw_average","Yhat", "tx")], + as.data.frame(mm)[c("ATT","raw_average","Yhat", "tx")] ) +} ) + +test_that( "Placebo distribtion works", { + + syn <- augsynth(gdpcap ~ trt, regionno, year, + data = basque, scm = TRUE, progfunc = "none" ) + tt <- treated_table(syn) + + donor_table(syn) + + b2 = basque %>% + mutate( trt = 0 + (regionno == 7 ) * (year>=1975) ) + syn7 <- augsynth(gdpcap ~ trt, regionno, year, + data = b2, scm = TRUE, progfunc = "none") + + # These should be the same? Or close to the same? + # (They were not under ridge, probably due to the cross-validation procedure.) 
+ gaps7 <- augsynth:::get_placebo_gaps(syn7) + g17 <- gaps7 %>% + dplyr::filter( ID == 17 ) %>% + pivot_longer( cols = `1955`:`1997`, + names_to = "year", + values_to = "est" ) + expect_equal( tt$ATT - g17$est, rep(0,nrow(tt)), tolerance = 0.000001 ) + + n_unit = length(unique(basque$regionno)) + expect_equal( nrow(gaps7), n_unit ) + n_year = length(unique(basque$year)) + + expect_equal( ncol(gaps7), 1 + n_year ) + expect_equal( dim( syn ), c( n_unit, n_year ) ) + + pc <- augsynth:::add_placebo_distribution( syn ) + rs <- pc$results$permutations + expect_equal( names(rs), c("placebo_dist", "MDES_table") ) + expect_equal( nrow( rs$placebo_dist ), n_year * n_unit ) + + expect_equal( syn$unit_var, "regionno" ) + expect_equal( syn$time_var, "year" ) +}) + + + +test_that("Inference carries through in summary objects", { + + syn <- augsynth(gdpcap ~ trt, regionno, year, + data = basque, scm = TRUE, progfunc = "none" ) + + sum <- summary( syn, inf_type = "permutation" ) + expect_equal( sum$inf_type, "permutation" ) +}) + + diff --git a/tests/testthat/test-summary.augsynth.R b/tests/testthat/test-summary.augsynth.R new file mode 100644 index 0000000..878eb44 --- /dev/null +++ b/tests/testthat/test-summary.augsynth.R @@ -0,0 +1,45 @@ +library( augsynth ) +set.seed( 4040440 ) + +n_units <- 12 +n_time <- 10 + +dat = augsynth:::make_synth_data( n_time = n_time, n_U = 5, N = n_units, long_form = TRUE, tx_impact = 2, tx_shift = 1 ) +syn = augsynth( Y ~ Tx | X1 + X2 + X3 + X4 + X5, unit = ID, time = time, data = dat, progfunc="none") + +test_that( "covariate table works", { + cov = covariate_balance_table( syn ) + cov + expect_true( all( paste0("X", 1:5 ) %in% cov$variable ) ) + + expect_output( print( cov ), "variable.*Tx.*Co.*Raw") + + c2 = covariate_balance_table( syn, pre_period = c("1", "2") ) + c2 + aa <- dat$Y[ dat$ever_Tx & dat$time %in% c("1", "2") ] + expect_equal( c2$Tx[ c2$variable == "2" ], aa[[2]] ) + expect_equal( c2$Tx[ c2$variable == "1" ], aa[[1]] ) +}) + + +test_that("summary.augsynth works", { + + expect_output( print( syn ), paste0('Fit to ', n_units, ' units and ', n_time, ' time points.*Average ATT Estimate')) + + sum = summary( syn ) + + # get donor table, add up number of units without weights and then check that the number in the summary is the same as that + n_donor <- sum$donor_table %>% filter(weight > 0) %>% nrow() + expect_output( print( sum ), glue::glue("{n_donor} donor units used with weights.")) + + s2 = summary( syn, inf_type = "jackknife+" ) + n_donor2 <- s2$donor_table %>% filter(weight > 0) %>% nrow() + expect_output( print( s2 ), + glue::glue("{n_donor2} donor units used with weights.*Avg Estimated Bias: .*Jackknife\\+ over time periods")) + + + s3 = summary( syn, inf_type = "none" ) + expect_output( print( s3 ), + "Inference type: None" ) + +}) diff --git a/tests/testthat/test_augsynth_pre.R b/tests/testthat/test_augsynth_pre.R index dd031ef..61f7fb2 100644 --- a/tests/testthat/test_augsynth_pre.R +++ b/tests/testthat/test_augsynth_pre.R @@ -126,3 +126,22 @@ test_that("augsynth runs single_synth with progfunc = 'ridge' when there is a si expect_equal(syn$weights, syn_single$weights) }) + + + +test_that("Check print.augsynth.summary writes to console", { + + data(basque) + basque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0, + regionno != 17 ~0, + regionno == 17 ~ 1)) %>% + filter(regionno != 1) + + syn <- augsynth(gdpcap ~ trt, regionno, year, basque, scm = T) + sum <- summary(syn) + + expect_output( print( sum ), "Conformal inference" ) +}) + + + diff 
--git a/tests/testthat/test_format.R b/tests/testthat/test_format.R index 58796e7..5ee6ff3 100644 --- a/tests/testthat/test_format.R +++ b/tests/testthat/test_format.R @@ -6,9 +6,9 @@ basque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0, regionno != 17 ~0, regionno == 17 ~ 1)) %>% filter(regionno != 1) - + test_that("format_data creates matrices with the right dimensions", { - + dat <- format_data(quo(gdpcap), quo(trt), quo(regionno), quo(year),1975, basque) test_dim <- function(obj, d) { @@ -23,7 +23,7 @@ test_that("format_data creates matrices with the right dimensions", { test_that("format_synth creates matrices with the right dimensions", { - + dat <- format_data(quo(gdpcap), quo(trt), quo(regionno), quo(year),1975, basque) syn_dat <- format_synth(dat$X, dat$trt, dat$y) test_dim <- function(obj, d) { @@ -115,4 +115,4 @@ test_that("augsynth exits with error message if there are no never treated units TRUE ~ 0)) expect_error(augsynth(gdpcap ~ trt, regionno, year, basque2), "1996") - }) \ No newline at end of file + }) diff --git a/tests/testthat/test_general.R b/tests/testthat/test_general.R index ccacb36..f593222 100644 --- a/tests/testthat/test_general.R +++ b/tests/testthat/test_general.R @@ -1,5 +1,7 @@ + context("Generally testing the workflow for augsynth") +library( tidyverse ) library(Synth) data(basque) @@ -8,20 +10,41 @@ basque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0, regionno == 17 ~ 1)) %>% filter(regionno != 1) +# fake_aug( gdpcap ~ trt, regionno, year, basque, progfunc="None", scm=T, t_int=1975 ) - test_that("SCM gives the right answer", { - syn <- augsynth(gdpcap ~ trt, regionno, year, basque, progfunc="None", scm=T, t_int=1975) + syn <- single_augsynth( gdpcap ~ trt, regionno, year, basque, + progfunc="None", scm=TRUE, t_int=1975) + sss = summary(syn, inf_type = 'none') ## average att estimate is as expected - expect_equal(-.3686, mean(summary(syn, inf = F)$att$Estimate), tolerance=1e-4) + expect_equal(-.3686, + mean(sss$att$Estimate), tolerance=1e-4) + expect_equal( syn$progfunc, "none" ) ## level of balance is as expected expect_equal(.377, syn$l2_imbalance, tolerance=1e-3) -} -) + expect_equal( dim( syn ), c( 17, 43 ) ) + expect_equal( n_unit( syn ), 17 ) + expect_equal( n_time( syn ), 43 ) + + # Try progfunc not defined. Get different answer? + syn2 <- single_augsynth(gdpcap ~ trt, regionno, year, basque, + scm=TRUE, t_int=1975) + + expect_equal( syn2$progfunc, "ridge" ) + ss = summary( syn2, inf_type = "none" ) + # different answers + expect_true( sd( ss$att$Estimate - sss$att$Estimate ) > 0.2 ) + + + # No SCM and no progfunc throws an error? 
+ expect_error( sraw <- single_augsynth(gdpcap ~ trt, regionno, year, basque, + progfunc="none", scm=FALSE, t_int=1975) ) + +}) test_that("SCM finds the correct t_int and gives the right answer", { @@ -30,12 +53,12 @@ test_that("SCM finds the correct t_int and gives the right answer", { syn2 <- augsynth(gdpcap ~ trt, regionno, year, basque, progfunc = "None", scm = T, t_int = 1975) ## average att estimate is as expected - expect_equal(mean(summary(syn1, inf = F)$att$Estimate), - mean(summary(syn2, inf = F)$att$Estimate), tolerance=1e-4) - + expect_equal(mean(summary(syn1, inf_type = 'none')$att$Estimate), + mean(summary(syn2, inf_type = 'none')$att$Estimate), tolerance=1e-4) + ## level of balance is as expected expect_equal(syn1$l2_imbalance, syn2$l2_imbalance, tolerance=1e-3) - + } ) @@ -46,11 +69,12 @@ test_that("Ridge ASCM gives the right answer", { scm=T, lambda=8) ## average att estimate is as expected - expect_equal(-.3696, mean(summary(asyn, inf = F)$att$Estimate), tolerance=1e-3) + expect_equal(-.3696, mean(summary(asyn, inf_type = 'none')$att$Estimate), tolerance=1e-3) ## level of balance is as expected expect_equal(.373, asyn$l2_imbalance, tolerance=1e-3) + } ) @@ -59,21 +83,21 @@ test_that("Ridge ASCM gives the right answer", { test_that("SCM after residualizing covariates gives the right answer", { - covsyn_resid <- augsynth(gdpcap ~ trt | invest + popdens, - regionno, year, basque, - progfunc = "None", scm = T, - residualize = T) + covsyn_resid <- augsynth(gdpcap ~ trt | invest + popdens, + regionno, year, basque, + progfunc = "None", scm = T, + residualize = T) - ## average att estimate is as expected - expect_equal(-.1443, - mean(summary(covsyn_resid, inf = F)$att$Estimate), - tolerance = 1e-3) + ## average att estimate is as expected + expect_equal(-.1443, + mean(summary(covsyn_resid, inf_type = 'none')$att$Estimate), + tolerance = 1e-3) - ## level of balance is as expected - expect_equal(.3720, covsyn_resid$l2_imbalance, tolerance=1e-3) + ## level of balance is as expected + expect_equal(.3720, covsyn_resid$l2_imbalance, tolerance=1e-3) - # perfect auxiliary covariate balance - expect_equal(0, covsyn_resid$covariate_l2_imbalance, tolerance=1e-3) + # perfect auxiliary covariate balance + expect_equal(0, covsyn_resid$covariate_l2_imbalance, tolerance=1e-3) } ) @@ -81,13 +105,13 @@ test_that("SCM after residualizing covariates gives the right answer", { test_that("Ridge ASCM with covariates jointly gives the right answer", { covsyn_noresid <- augsynth(gdpcap ~ trt | invest + popdens, - regionno, year, basque, - progfunc = "None", scm = T, - residualize = F) + regionno, year, basque, + progfunc = "None", scm = T, + residualize = F) ## average att estimate is as expected expect_equal(-.3345, - mean(summary(covsyn_noresid, inf = F)$att$Estimate), + mean(summary(covsyn_noresid, inf_type = 'none')$att$Estimate), tolerance = 1e-3) ## level of balance is as expected @@ -104,14 +128,14 @@ test_that("Ridge ASCM with covariates jointly gives the right answer", { test_that("Ridge ASCM after residualizing covariates gives the right answer", { covascm_resid <- augsynth(gdpcap ~ trt | invest + popdens, - regionno, year, basque, - progfunc = "Ridge", scm = T, - lambda = 1, - residualize = T) + regionno, year, basque, + progfunc = "Ridge", scm = T, + lambda = 1, + residualize = T) ## average att estimate is as expected expect_equal(-.123, - mean(summary(covascm_resid, inf = F)$att$Estimate), + mean(summary(covascm_resid, inf_type = 'none')$att$Estimate), tolerance = 1e-3) ## level of balance 
is as expected @@ -126,14 +150,14 @@ test_that("Ridge ASCM after residualizing covariates gives the right answer", { test_that("Ridge ASCM with covariates jointly gives the right answer", { covascm_noresid <- augsynth(gdpcap ~ trt | invest + popdens, - regionno, year, basque, - progfunc = "Ridge", scm = T, - lambda = 1, - residualize = F) + regionno, year, basque, + progfunc = "Ridge", scm = T, + lambda = 1, + residualize = F) ## average att estimate is as expected expect_equal(-.267, - mean(summary(covascm_noresid, inf = F)$att$Estimate), + mean(summary(covascm_noresid, inf_type = 'none')$att$Estimate), tolerance = 1e-3) ## level of balance is as expected @@ -149,8 +173,9 @@ test_that("Ridge ASCM with covariates jointly gives the right answer", { test_that("Warning given when inputting an unused argument", { expect_warning( - augsynth(gdpcap ~ trt| invest + popdens, regionno, year, basque, - progfunc="Ridge", scm=T, lambda=8, t_int = 1975, - bad_param = "Unused input parameter"), + augsynth(gdpcap ~ trt| invest + popdens, regionno, year, basque, + progfunc="Ridge", scm=T, lambda=8, t_int = 1975, + bad_param = "Unused input parameter"), ) }) + diff --git a/tests/testthat/test_inference.R b/tests/testthat/test_inference.R new file mode 100644 index 0000000..cbc3841 --- /dev/null +++ b/tests/testthat/test_inference.R @@ -0,0 +1,15 @@ + + + +context("Test inference features of single augsynth objects") + +basque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0, + regionno != 17 ~0, + regionno == 17 ~ 1)) %>% + filter(regionno != 1) + +test_that("Default model doesn't contain inference", { + syn <- augsynth(gdpcap ~ trt, regionno, year, basque, progfunc="Ridge", scm = T) + expect_equivalent(is.null(syn$results), TRUE) +}) + diff --git a/tests/testthat/test_multisynth.R b/tests/testthat/test_multisynth.R index 79008a1..c87f14a 100644 --- a/tests/testthat/test_multisynth.R +++ b/tests/testthat/test_multisynth.R @@ -1,5 +1,6 @@ context("Generally testing the workflow for multisynth") +library( tidyverse ) library(Synth) data(basque) basque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0, @@ -8,14 +9,16 @@ basque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0, filter(regionno != 1) - + test_that("augsynth and multisynth give the same answer for a single treated unit and no augmentation", { syn <- single_augsynth(gdpcap ~ trt, regionno, year, 1975, basque, progfunc="None", scm=T, fixedeff = F) msyn <- multisynth(gdpcap ~ trt, regionno, year, basque, nu = 0, fixedeff = F, scm=T, eps_rel=1e-5, eps_abs=1e-5) - + + expect_equal( dim( msyn ), dim( syn ) ) + # weights are the same-ish expect_equal(c(syn$weights), c(msyn$weights[-16]), tolerance=3e-4) @@ -39,6 +42,10 @@ test_that("Pooling doesn't matter for a single treated unit", { allpool <- multisynth(gdpcap ~ trt, regionno, year, basque, nu = 1, scm=T, eps_rel=1e-5, eps_abs=1e-5) + expect_equal( dim( nopool ), dim( allpool ) ) + expect_equal( n_treated(nopool), n_treated(allpool) ) + expect_equal( n_treated(nopool), 1 ) + # weights are the same expect_equal(nopool$weights, allpool$weights) @@ -56,7 +63,7 @@ test_that("Pooling doesn't matter for a single treated unit", { - + test_that("Separate synth is the same as fitting separate synths", { @@ -66,23 +73,27 @@ test_that("Separate synth is the same as fitting separate synths", { filter(regionno != 1) - basque2 %>% filter(regionno != 16) %>% + basque2 %>% filter(regionno != 16) %>% single_augsynth(gdpcap ~ trt, regionno, year, 1975, ., progfunc="None", scm=T) -> scm17 - basque2 %>% 
filter(regionno != 17) %>% + basque2 %>% filter(regionno != 17) %>% single_augsynth(gdpcap ~ trt, regionno, year, 1975, ., progfunc="None", scm=T) -> scm16 - + msyn <- multisynth(gdpcap ~ trt, regionno, year, basque2, nu = 0, scm=T, eps_rel=1e-5, eps_abs=1e-5, fixedeff = F) - + + expect_equal( dim( scm17 ), c(16,43) ) + expect_equal( n_treated(scm17), 1 ) + expect_equal( n_treated(msyn), 2 ) + # weights are the same-ish sscm_weights <- unname(c(scm17$weights)) mscm_weights <- unname(c(msyn$weights[-c(15, 16), 2])) expect_equal(sscm_weights, mscm_weights, tolerance=3e-2) expect_equal(rownames(scm17$weights), rownames(as.matrix(msyn$weights[-c(15, 16), 2]))) # expect_equal(c(scm16$weights), c(msyn$weights[-c(15, 16), 1]), tolerance=3e-2) - + # estimates are the same-ish pred_msyn <- predict(msyn, att=F) pred_msyn <- pred_msyn[-nrow(pred_msyn), ] @@ -101,7 +112,7 @@ test_that("Limiting number of lags works", { expect_error( multisynth(gdpcap ~ trt, regionno, year, basque2, nu = 0, - scm=T, eps_rel=1e-5, eps_abs=1e-5, n_lags =3), + scm=T, eps_rel=1e-5, eps_abs=1e-5, n_lags = 3), NA ) } @@ -117,6 +128,14 @@ test_that("L2 imbalance computed correctly", { msyn <- multisynth(gdpcap ~ trt, regionno, year, basque2, scm=T, eps_rel=1e-5, eps_abs=1e-5) + expect_true( "weights" %in% names( msyn ) ) + + ntx = length( unique( basque2$regionno[ basque2$trt == 1 ] ) ) + expect_equal( dim( msyn$weights ), c( 17, ntx ) ) + + expect_true( "data" %in% names( msyn ) ) + #msyn$data + glbl <- sqrt(mean(msyn$imbalance[,1]^2)) ind <- sqrt(mean( apply(msyn$imbalance[, -1], 2, @@ -165,7 +184,7 @@ test_that("V matrix is the same for single and multi synth", { ) - + test_that("multisynth doesn't depend on unit order", { basque2 <- basque %>% mutate(trt = case_when(year < 1975 ~ 0, @@ -180,7 +199,7 @@ test_that("multisynth doesn't depend on unit order", { basque2 %>% arrange(desc(regionno)), nu = 0, fixedeff = F, scm=T, eps_rel=1e-5, eps_abs=1e-5) - + # weights are the same expect_equal(c(msyn$weights), c(msyn2$weights)) @@ -191,7 +210,7 @@ test_that("multisynth doesn't depend on unit order", { ) - + test_that("multisynth doesn't depend on time order", { basque2 <- basque %>% mutate(trt = case_when(year < 1975 ~ 0, @@ -206,7 +225,7 @@ test_that("multisynth doesn't depend on time order", { basque2 %>% arrange(desc(year)), nu = 0, fixedeff = F, scm=T, eps_rel=1e-5, eps_abs=1e-5) - + # weights are the same expect_equal(c(msyn$weights), c(msyn2$weights)) diff --git a/tests/testthat/test_outcome_models.R b/tests/testthat/test_outcome_models.R index 2eccdb9..35633f0 100644 --- a/tests/testthat/test_outcome_models.R +++ b/tests/testthat/test_outcome_models.R @@ -10,7 +10,7 @@ basque <- basque %>% mutate(trt = case_when(year < 1975 ~ 0, filter(regionno != 1) - + test_that("Augmenting synth with glmnet runs", { if(!requireNamespace("glmnet", quietly = TRUE)) { @@ -24,7 +24,7 @@ test_that("Augmenting synth with glmnet runs", { ## should run because glmnet is installed expect_error(augsynth(gdpcap ~ trt, regionno, year, basque, progfunc="EN", scm=T), - NA) + NA) } ) @@ -43,7 +43,7 @@ test_that("Augmenting synth with random forest runs", { ## should run because randomForest is installed expect_error(augsynth(gdpcap ~ trt, regionno, year, basque, progfunc="RF", scm=T), - NA) + NA) } ) @@ -51,10 +51,11 @@ test_that("Augmenting synth with random forest runs", { test_that("Augmenting synth with gsynth runs and produces the correct result", { + skip("temporarily disabled") if(!requireNamespace("gsynth", quietly = TRUE)) { ## should fail 
because gsynth isn't installed - expect_error(augsynth(gdpcap ~ trt, regionno, year, basque, + expect_error(augsynth(gdpcap ~ trt, regionno, year, basque, progfunc="GSYN", scm=T), "you must install the gsynth package") @@ -64,13 +65,13 @@ test_that("Augmenting synth with gsynth runs and produces the correct result", { ## should run because gsynth is installed expect_error( - augsynth(gdpcap ~ trt, regionno, year, basque, + augsynth(gdpcap ~ trt, regionno, year, basque, progfunc = "GSYN", scm = T, CV = 0, r = 4), NA) asyn_gsyn <- augsynth(gdpcap ~ trt, regionno, year, basque, progfunc = "GSYN", scm = F, CV = 0, r = 4) - expect_equal(summary(asyn_gsyn, inf = F)$average_att$Estimate, - -0.1444637, tolerance=1e-4) + expect_equal(summary(asyn_gsyn, inf_type = 'none')$average_att$Estimate, + -0.1444637, tolerance=1e-4) } ) @@ -85,10 +86,10 @@ test_that("Augmenting synth with MCPanel runs", { } else { ## should run because MCPanel is installed expect_error(augsynth(gdpcap ~ trt, regionno, year, basque, progfunc="MCP", scm=T), - NA) + NA) } - + } ) @@ -108,6 +109,6 @@ test_that("Augmenting synth with CausalImpact runs", { ## should run because CausalImpact is installed expect_error(augsynth(gdpcap ~ trt, regionno, year, basque, progfunc="CausalImpact", scm=T), - NA) + NA) } ) diff --git a/tests/testthat/test_unbalanced_multisynth.R b/tests/testthat/test_unbalanced_multisynth.R index 7cf8c9b..eebab67 100644 --- a/tests/testthat/test_unbalanced_multisynth.R +++ b/tests/testthat/test_unbalanced_multisynth.R @@ -122,11 +122,11 @@ test_that("Separate synth with missing control unit time drops control unit", { # drop a time period for unit 17 basque %>% filter(!regionno %in% c(18) | year != 1970) -> basque_mis - - msyn <- multisynth(gdpcap ~ trt, regionno, year, basque_mis, + + msyn <- multisynth(gdpcap ~ trt, regionno, year, basque_mis, nu = 0, scm=T, eps_rel=1e-8, eps_abs=1e-8) - msyn2 <- multisynth(gdpcap ~ trt, regionno, year, + msyn2 <- multisynth(gdpcap ~ trt, regionno, year, basque %>% filter(regionno != 18), nu = 0, scm=T, eps_rel=1e-8, eps_abs=1e-8) @@ -143,10 +143,10 @@ test_that("Separate synth with missing control unit only in post-treatment perio dat_format <- format_data_stag(quo(gdpcap), quo(trt), quo(regionno), quo(year), basque_mis) expect_true(nrow(dat_format$X) == nrow(dat_format$y)) - msyn <- multisynth(gdpcap ~ trt, regionno, year, basque_mis, + msyn <- multisynth(gdpcap ~ trt, regionno, year, basque_mis, nu = 0, scm=T, eps_rel=1e-8, eps_abs=1e-8) - msyn2 <- multisynth(gdpcap ~ trt, regionno, year, + msyn2 <- multisynth(gdpcap ~ trt, regionno, year, basque %>% filter(regionno != 18), nu = 0, scm=T, eps_rel=1e-8, eps_abs=1e-8) @@ -163,10 +163,10 @@ test_that("Separate synth with missing control unit only in pre-treatment period expect_true(nrow(dat_format$X) == nrow(dat_format$y)) - msyn <- multisynth(gdpcap ~ trt, regionno, year, basque_mis, + msyn <- multisynth(gdpcap ~ trt, regionno, year, basque_mis, nu = 0, scm=T, eps_rel=1e-8, eps_abs=1e-8) - msyn2 <- multisynth(gdpcap ~ trt, regionno, year, + msyn2 <- multisynth(gdpcap ~ trt, regionno, year, basque %>% filter(regionno != 18), nu = 0, scm=T, eps_rel=1e-8, eps_abs=1e-8) @@ -180,7 +180,7 @@ test_that("Multisynth with unbalanced panels runs", { basque %>% filter(!regionno %in% c(15, 17) | year != 1970) -> basque_mis - msyn <- multisynth(gdpcap ~ trt, regionno, year, basque_mis, + msyn <- multisynth(gdpcap ~ trt, regionno, year, basque_mis, scm=T, eps_rel=1e-8, eps_abs=1e-8) expect_error(summary(msyn), NA) @@ -193,7 +193,7 @@ 
test_that("Multisynth with unbalanced panels runs with missing post-treatment", basque %>% filter(!regionno %in% c(15, 17) | year != 1990) -> basque_mis - msyn <- multisynth(gdpcap ~ trt, regionno, year, basque_mis, + msyn <- multisynth(gdpcap ~ trt, regionno, year, basque_mis, scm=T, eps_rel=1e-8, eps_abs=1e-8) expect_error(summary(msyn), NA) @@ -207,8 +207,8 @@ test_that("Multisynth with unbalanced panels runs", { basque %>% filter(!regionno %in% c(15) | year != 1985) -> basque_mis - msyn <- multisynth(gdpcap ~ trt, regionno, year, basque_mis, + msyn <- multisynth(gdpcap ~ trt, regionno, year, basque_mis, scm=T, eps_rel=1e-8, eps_abs=1e-8) expect_error(summary(msyn), NA) -}) \ No newline at end of file +}) diff --git a/vignettes/singlesynth-vignette.Rmd b/vignettes/singlesynth-vignette.Rmd index 501de37..4b45fcc 100644 --- a/vignettes/singlesynth-vignette.Rmd +++ b/vignettes/singlesynth-vignette.Rmd @@ -99,10 +99,10 @@ asyn <- augsynth(lngdpcapita ~ treated, fips, year_qtr, kansas, progfunc = "Ridge", scm = T) ``` -We can plot the cross-validation MSE when dropping pre-treatment time periods by setting `cv = T` in the `plot` function: +We can plot the cross-validation MSE when dropping pre-treatment time periods by setting `plot_type = "cv"` in the `plot` function: ```{r fig_asyn_cv, fig.width=8, fig.height=4.5, echo=T, fig.align="center"} -plot(asyn, cv = T) +plot(asyn, plot_type = 'cv') ``` By default, the CV procedure chooses the maximal value of `lambda` with MSE within one standard deviation of the minimal MSE. To instead choose the `lambda` that minizes the cross validation MSE, set `min_1se = FALSE`.