Estimating individual treatment effect: generalization bounds and algorithms

There is intense interest in applying machine learning to problems of causal inference in fields such as healthcare, economics and education. In particular, individual-level causal inference has important applications such as precision medicine. We give a new theoretical analysis and family of algorithms for predicting individual treatment effect (ITE) from observational data, under the assumption known as strong ignorability. The algorithms learn a "balanced" representation such that the induced treated and control distributions look similar. We give a novel, simple and intuitive generalization-error bound showing that the expected ITE estimation error of a representation is bounded by a sum of the standard generalization-error of that representation and the distance between the treated and control distributions induced by the representation. We use Integral Probability Metrics to measure distances between distributions, deriving explicit bounds for the Wasserstein and Maximum Mean Discrepancy (MMD) distances. Experiments on real and simulated data show the new algorithms match or outperform the state-of-the-art.
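The central idea of the abstract is to add a distance between the treated and control distributions of the learned representation to the factual prediction loss. It can be illustrated with a small numerical sketch. This is not the authors' implementation. The representation (one random ReLU layer), the separate linear outcome heads per treatment arm, the Gaussian (RBF) kernel and its bandwidth `sigma`, the trade-off weight `alpha`, and all function names are hypothetical choices made for illustration. Only the MMD distance is shown; the Wasserstein variant is omitted.

```python
import numpy as np

def rbf_mmd2(a, b, sigma=1.0):
    """Biased estimate of the squared MMD between samples a and b (Gaussian kernel)."""
    def k(x, y):
        # Pairwise squared Euclidean distances, then Gaussian kernel values.
        d2 = np.sum(x**2, 1)[:, None] + np.sum(y**2, 1)[None, :] - 2 * x @ y.T
        return np.exp(-d2 / (2 * sigma**2))
    return k(a, a).mean() + k(b, b).mean() - 2 * k(a, b).mean()

def representation(x, W):
    # Hypothetical representation Phi: a single linear layer followed by ReLU.
    return np.maximum(x @ W, 0.0)

def balanced_objective(x, t, y, W, V0, V1, alpha=1.0, sigma=1.0):
    """Factual loss plus alpha times the squared RBF-MMD between the treated
    and control representations (an illustrative, hypothetical objective)."""
    phi = representation(x, W)
    # Hypothetical outcome heads: V0 for control units (t=0), V1 for treated units (t=1).
    y_hat = np.where(t == 1, phi @ V1, phi @ V0)
    factual_loss = np.mean((y_hat - y) ** 2)
    imbalance = rbf_mmd2(phi[t == 1], phi[t == 0], sigma)
    return factual_loss + alpha * imbalance, factual_loss, imbalance

# Usage on synthetic data where treatment depends on x[:, 0], so the groups differ.
rng = np.random.default_rng(0)
n, d, r = 200, 5, 8
x = rng.normal(size=(n, d))
t = (rng.random(n) < 1 / (1 + np.exp(-2 * x[:, 0]))).astype(int)
y = x[:, 1] + 2.0 * t + 0.1 * rng.normal(size=n)
W = rng.normal(size=(d, r)) / np.sqrt(d)
V0, V1 = rng.normal(size=r), rng.normal(size=r)
total, factual, imbalance = balanced_objective(x, t, y, W, V0, V1, alpha=0.5)
```

In a trained model, `W`, `V0` and `V1` would be fitted by minimizing such an objective. A larger `alpha` trades factual fit for a representation in which treated and control units look more alike.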

Published at ICML 2017.




Results from the Paper

Task              Dataset  Model                             Metric                                         Value  Global rank
Causal Inference  IHDP     Balancing Linear Regression       Average Treatment Effect Error                 0.93   #11
Causal Inference  IHDP     Random Forest                     Average Treatment Effect Error                 0.96   #12
Causal Inference  IHDP     k-NN                              Average Treatment Effect Error                 0.79   #10
Causal Inference  IHDP     Counterfactual Regression + WASS  Average Treatment Effect Error                 0.27   #4
Causal Inference  IHDP     TARNet                            Average Treatment Effect Error                 0.28   #5
Causal Inference  IHDP     Balancing Neural Network          Average Treatment Effect Error                 0.42   #8
Causal Inference  IHDP     Causal Forest                     Average Treatment Effect Error                 0.4    #7
Causal Inference  Jobs     CFR WASS                          Average Treatment Effect on the Treated Error  0.09   #5
Causal Inference  Jobs     CFR MMD                           Average Treatment Effect on the Treated Error  0.08   #3
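Both metrics in the table are absolute errors of averaged effects, so lower is better. The following is a minimal sketch of how they can be computed. It assumes the per-unit effect estimates `tau_hat` (for example f(x, 1) - f(x, 0) from a fitted model) are given. For the ATE error, the true per-unit effects are assumed known, as they are in a semi-synthetic benchmark such as IHDP. For the ATT error, the true value `att_true` is assumed to be supplied from elsewhere; on Jobs it would have to be estimated, for example from randomized data. The function names are hypothetical.

```python
import numpy as np

def ate_error(tau_hat, tau_true):
    """Absolute error of the average treatment effect over all units."""
    return abs(np.mean(tau_hat) - np.mean(tau_true))

def att_error(tau_hat, t, att_true):
    """Absolute error of the average treatment effect on the treated (t == 1).
    att_true is an externally supplied (assumed) reference value."""
    return abs(np.mean(tau_hat[t == 1]) - att_true)

tau_hat = np.array([1.0, 2.0, 3.0, 4.0])
tau_true = np.array([1.5, 1.5, 2.5, 3.5])
t = np.array([0, 1, 0, 1])
print(ate_error(tau_hat, tau_true))  # 0.25
print(att_error(tau_hat, t, 2.5))    # 0.5
```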