diff --git a/causalml/dataset/synthetic.py b/causalml/dataset/synthetic.py index 983784fd..358ebb29 100644 --- a/causalml/dataset/synthetic.py +++ b/causalml/dataset/synthetic.py @@ -544,10 +544,10 @@ def scatter_plot_summary_holdout( for g in np.unique(group): ix = np.where(group == g)[0].tolist() - ax.scatter(xs[ix], ys[ix], c=cdict[g], label=g, s=100) + ax.scatter(xs.iloc[ix], ys.iloc[ix], c=cdict[g], label=g, s=100) for i, txt in enumerate(plot_data.label[:10]): - ax.annotate(txt, (xs[i] + 0.005, ys[i])) + ax.annotate(txt, (xs.iloc[i] + 0.005, ys.iloc[i])) ax.set_xlabel("Abs % Error of ATE") ax.set_ylabel("MSE")