diff --git a/causality/estimation/parametric.py b/causality/estimation/parametric.py index 997e440..064f3c3 100644 --- a/causality/estimation/parametric.py +++ b/causality/estimation/parametric.py @@ -146,3 +146,27 @@ def estimate_ATE(self, X, assignment, outcome, confounder_types, n_neighbors=5): att = estimate_ATT(self, X, assignment, outcome, confounder_types, n_neighbors=n_neighbors) atc = estimate_ATC(self, X, assignment, outcome, confounder_types, n_neighbors=n_neighbors) return (atc+att)/2. + + +class RegressionDiscontinuity(object): + def __init__ (self, robust=True): + if robust: + self.model = RLM + else: + self.model = OLM + + def estimate_ATE(self, X, continuous='continuous', outcome='outcome', cutoff=0., delta=0.1, indicator='D', + intercept='intercept', store_result=False): + slice = X[X[continuous] < cutoff + delta] + slice = slice[slice[continuous] > cutoff - delta] + slice.loc[:,continuous] = slice[continuous] - cutoff + slice.loc[:, indicator] = (slice[continuous] > 0).apply(int) + slice.loc[:, indicator+'_'+continuous] = slice[indicator] * slice[continuous] + slice.loc[:, intercept] = 1. + model = self.model(slice[outcome], slice[[intercept, indicator+'_'+continuous, indicator, continuous]]) + result = model.fit() + if store_result: + self.result = result + + def check_assumptions(self): + pass diff --git a/tests/unit/parametric.py b/tests/unit/parametric.py index f22b122..8db779b 100644 --- a/tests/unit/parametric.py +++ b/tests/unit/parametric.py @@ -91,3 +91,6 @@ def test_at_estimators(self): X = pd.DataFrame({'att': atts, 'ate': ates, 'atc': atcs}) assert (3.0 <= X.mean()).all() assert (X.mean() <= 4.0).all() + + +