
Implement DynamicDML #446

Merged · 36 commits · Aug 11, 2021

Commits
762e0e4
Implement DynamicDML
Mar 31, 2021
1070aea
Add performance tests and an example notebook
Apr 7, 2021
5f6da40
Add scores.
Apr 9, 2021
d8bc1f3
store some internal variables to allow calling from diased inference …
heimengqi Jun 4, 2021
6daf315
Swap t and j indexes to match the paper
Jul 29, 2021
0615d70
Update covariance matrix to include off-diagonal elements
Jul 29, 2021
7dc65b8
Add support for out of order groups
Jul 31, 2021
efd634d
Implement score
Jul 31, 2021
a508615
Merge branch 'master' into moprescu/dynamicdml
Aug 2, 2021
ac4dd70
Update docstring test outputs
Aug 2, 2021
a44a960
Fix merge issues
Aug 2, 2021
1950fd1
Address PR suggestions
Aug 2, 2021
28a92b6
Merge branch 'master' into moprescu/dynamicdml
Aug 2, 2021
4636257
Fix subscript printing in summary
Aug 3, 2021
9328a22
Address PR suggestions
Aug 5, 2021
24ca086
Update nuisance models in notebook
Aug 5, 2021
a39f1b5
Reverse effect indices to match paper
Aug 6, 2021
4210d1d
Add sample code to README
Aug 6, 2021
e7e7289
Merge branch 'master' into moprescu/dynamicdml
Aug 6, 2021
4cc1156
Adjust heterogeneity to depend only on features from the first period
Aug 6, 2021
e74067e
moved dynamic_dml to separate module. fixed remaining bugs in dgp. fi…
vsyrgkanis Aug 6, 2021
99e62e7
fixed ref in doc
vsyrgkanis Aug 6, 2021
a675564
Merge branch 'master' into moprescu/dynamicdml
vsyrgkanis Aug 6, 2021
1cf663a
doc bug
vsyrgkanis Aug 6, 2021
669c284
relaxed dynamci dml tests
vsyrgkanis Aug 6, 2021
0420656
fixed doctest
vsyrgkanis Aug 6, 2021
42c65dd
add ROI notebook
heimengqi Aug 8, 2021
8691c00
Merge branch 'master' into moprescu/dynamicdml
vsyrgkanis Aug 8, 2021
1c069e1
Merge branch 'master' into moprescu/dynamicdml
vsyrgkanis Aug 8, 2021
a5bd8e9
update setup to install jbl file
heimengqi Aug 9, 2021
e155ac6
Merge branch 'moprescu/dynamicdml' of https://github.com/microsoft/Ec…
heimengqi Aug 9, 2021
da8085d
update setup to install jbl file
heimengqi Aug 9, 2021
ebef40b
update roi notebook
heimengqi Aug 9, 2021
f0a5a29
Merge branch 'master' into moprescu/dynamicdml
heimengqi Aug 11, 2021
e77c568
Limit test paralellization
kbattocchi Aug 9, 2021
6aadad9
Merge branch 'master' into moprescu/dynamicdml
heimengqi Aug 11, 2021
20 changes: 20 additions & 0 deletions README.md
@@ -33,6 +33,7 @@ For information on use cases and background material on causal inference and het
- [Interpretability](#interpretability)
- [Causal Model Selection and Cross-Validation](#causal-model-selection-and-cross-validation)
- [Inference](#inference)
- [Policy Learning](#policy-learning)
- [For Developers](#for-developers)
- [Running the tests](#running-the-tests)
- [Generating the documentation](#generating-the-documentation)
@@ -162,6 +163,25 @@ To install from source, see [For Developers](#for-developers) section below.

</details>

<details>
<summary>Dynamic Double Machine Learning (click to expand)</summary>

```Python
from econml.dynamic.dml import DynamicDML
from sklearn.linear_model import LassoCV

# Use defaults
est = DynamicDML()
# Or specify hyperparameters
est = DynamicDML(model_y=LassoCV(cv=3),
                 model_t=LassoCV(cv=3),
                 cv=3)
est.fit(Y, T, X=X, W=None, groups=groups, inference="auto")
# Effects
treatment_effects = est.effect(X_test)
# Confidence intervals
lb, ub = est.effect_interval(X_test, alpha=0.05)
```
</details>
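For orientation (editorial context, not part of the library's README): `DynamicDML` expects panel data in long format, one row per unit-period, with `groups` marking which rows belong to the same unit. A minimal sketch of assembling such inputs with NumPy — the 3-period design and array shapes are illustrative assumptions, not API requirements:

```python
import numpy as np

# Illustrative panel: 100 units, each observed for 3 consecutive periods.
n_units, n_periods = 100, 3
rng = np.random.default_rng(0)

# Rows of the same unit must be contiguous; groups labels each row with its unit id.
groups = np.repeat(np.arange(n_units), n_periods)  # [0, 0, 0, 1, 1, 1, ...]
X = rng.normal(size=(n_units * n_periods, 2))      # heterogeneity features
T = rng.normal(size=(n_units * n_periods, 1))      # treatment in each period
Y = rng.normal(size=(n_units * n_periods,))        # outcome in each period
```

These arrays line up with the `est.fit(Y, T, X=X, W=None, groups=groups)` call shown above.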

<details>
<summary>Causal Forests (click to expand)</summary>

15 changes: 15 additions & 0 deletions doc/reference.rst
@@ -104,6 +104,21 @@ Sieve Methods
econml.iv.sieve.HermiteFeatures
econml.iv.sieve.DPolynomialFeatures

.. _dynamic_api:

Estimators for Dynamic Treatment Regimes
----------------------------------------

.. _dynamicdml_api:

Dynamic Double Machine Learning
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autosummary::
    :toctree: _autosummary

    econml.dynamic.dml.DynamicDML

.. _policy_api:

Policy Learning
6 changes: 4 additions & 2 deletions doc/spec/estimation/dml.rst
@@ -50,6 +50,7 @@ characteristics :math:`X` of the treated samples, then one can use this method.

.. testsetup::

    # DML
    import numpy as np
    X = np.random.choice(np.arange(5), size=(100,3))
    Y = np.random.normal(size=(100,2))
@@ -71,8 +72,9 @@ Most of the methods provided make a parametric form assumption on the heterogene
linear on some pre-defined; potentially high-dimensional; featurization). These methods include:
:class:`.DML`, :class:`.LinearDML`,
:class:`.SparseLinearDML`, :class:`.KernelDML`.
For fullly non-parametric heterogeneous treatment effect models, checkout the :class:`.NonParamDML`
and the :class:`.CausalForestDML`. For more options of non-parametric CATE estimators,
For fully non-parametric heterogeneous treatment effect models, check out the :class:`.NonParamDML`
and the :class:`.CausalForestDML`.
For more options of non-parametric CATE estimators,
check out the :ref:`Forest Estimators User Guide <orthoforestuserguide>`
and the :ref:`Meta Learners User Guide <metalearnersuserguide>`.

95 changes: 95 additions & 0 deletions doc/spec/estimation/dynamic_dml.rst
@@ -0,0 +1,95 @@
.. _dynamicdmluserguide:

===============================
Dynamic Double Machine Learning
===============================

What is it?
==================================

Dynamic Double Machine Learning is a method for estimating (heterogeneous) treatment effects when
treatments are offered over time via an adaptive dynamic policy. It applies to the case when
all potential dynamic confounders/controls (factors that simultaneously had a direct effect on the adaptive treatment
decision in the collected data and on the observed outcome) are observed, but are either too numerous (high-dimensional) for
classical statistical approaches to be applicable, or their effect on
the treatment and outcome cannot be satisfactorily modeled by parametric functions (non-parametric).
Both of these problems can be addressed via machine learning techniques (see e.g. [Lewis2021]_).


What are the relevant estimator classes?
========================================

This section describes the methodology implemented in the class
:class:`.DynamicDML`.
Click on the link for detailed module documentation and the input parameters of the class.


When should you use it?
==================================

Suppose you have observational (or experimental, e.g. from an A/B test) historical data in which multiple treatments/interventions/actions
:math:`T` were offered over time to each unit, some final outcome(s) :math:`Y` was observed, and all the variables :math:`W` that could have
potentially gone into the choice of :math:`T`, and simultaneously could have had a direct effect on the outcome :math:`Y` (aka controls or confounders), are also recorded in the dataset.

If your goal is to understand the effect of the treatment on the outcome as a function of a set of observable
characteristics :math:`X` of the treated samples, then you can use this method. For instance, call:

.. testsetup::

    # DynamicDML
    import numpy as np
    groups = np.repeat(a=np.arange(100), repeats=3, axis=0)
    W_dyn = np.random.normal(size=(300, 1))
    X_dyn = np.random.normal(size=(300, 1))
    T_dyn = np.random.normal(size=(300, 2))
    y_dyn = np.random.normal(size=(300, ))

.. testcode::

    from econml.dynamic.dml import DynamicDML
    est = DynamicDML()
    est.fit(y_dyn, T_dyn, X=X_dyn, W=W_dyn, groups=groups)


Class Hierarchy Structure
==================================

In this library we implement the dynamic DML approach of [Lewis2021]_. The hierarchy
structure of the implemented CATE estimators is as follows.

.. inheritance-diagram:: econml.dynamic.dml.DynamicDML
    :parts: 1
    :private-bases:
    :top-classes: econml._OrthoLearner, econml._cate_estimator.LinearModelFinalCateEstimatorMixin

Below we give a brief description of each of these classes:

* **DynamicDML.** The class :class:`.DynamicDML` is an extension of the Double ML approach for treatments assigned sequentially over time periods.
This estimator adjusts for treatments that can have causal effects on future outcomes. The data corresponds to a Markov decision process :math:`\{X_t, W_t, T_t, Y_t\}_{t=1}^m`,
where :math:`X_t, W_t` correspond to the state at time :math:`t`, :math:`T_t` is the treatment at time :math:`t`, and :math:`Y_t` is the observed outcome at time :math:`t`.

The model makes the following structural equation assumptions on the data generating process:

.. math::

    XW_t =~& A \cdot T_{t-1} + B \cdot XW_{t-1} + \eta_t\\
    T_t =~& p(T_{t-1}, XW_t, \zeta_t) \\
    Y_t =~& \theta_0(X_0)'T_t + \mu'XW_t + \epsilon_t

where :math:`XW` is the concatenation of the :math:`X` and :math:`W` variables.
For more details about this model and underlying assumptions, see [Lewis2021]_.

To learn the effects of the treatments in the different periods on the final-period outcome, one can simply call:

.. testcode::

    from econml.dynamic.dml import DynamicDML
    est = DynamicDML()
    est.fit(y_dyn, T_dyn, X=X_dyn, W=W_dyn, groups=groups)
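The structural equations above can also be simulated directly. Here is a small NumPy sketch of that data generating process; the scalar coefficients `A`, `B`, `theta`, `mu` and the linear policy rule are illustrative choices for this sketch, not values from [Lewis2021]_:

```python
import numpy as np

rng = np.random.default_rng(1)
n_units, n_periods = 500, 3
A, B = 0.5, 0.8        # state dynamics coefficients (illustrative)
theta, mu = 1.0, 0.3   # treatment effect and control coefficient (illustrative)

T_panel = np.zeros((n_units, n_periods))
Y = np.zeros(n_units)
for i in range(n_units):
    xw, t_prev = rng.normal(), 0.0
    for t in range(n_periods):
        # XW_t = A * T_{t-1} + B * XW_{t-1} + eta_t
        xw = A * t_prev + B * xw + rng.normal()
        # adaptive policy: T_t = p(T_{t-1}, XW_t, zeta_t), here a linear rule
        t_prev = 0.2 * t_prev + 0.5 * xw + rng.normal()
        T_panel[i, t] = t_prev
    # final-period outcome: Y = theta * T_m + mu * XW_m + eps
    Y[i] = theta * t_prev + mu * xw + rng.normal()
```

Because each period's treatment depends on a state that earlier treatments influenced, ordinary cross-sectional DML would be biased here, which is the setting `DynamicDML` targets.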



Usage FAQs
==========

See our FAQ section in the :ref:`DML User Guide <dmluserguide>`.
11 changes: 11 additions & 0 deletions doc/spec/estimation_dynamic.rst
@@ -0,0 +1,11 @@
Estimation Methods for Dynamic Treatment Regimes
================================================

This section contains methods for estimating (heterogeneous) treatment effects,
even when treatments are offered over time and were chosen based on a dynamic
adaptive policy. This setting is referred to as the dynamic treatment regime (see e.g. [Hernan2010]_).

.. toctree::
    :maxdepth: 2

    estimation/dynamic_dml
12 changes: 11 additions & 1 deletion doc/spec/references.rst
@@ -113,4 +113,14 @@ References
.. [Lundberg2017]
    Lundberg, S., Lee, S. (2017).
    A Unified Approach to Interpreting Model Predictions.
    URL https://arxiv.org/abs/1705.07874

.. [Lewis2021]
    Lewis, G., Syrgkanis, V. (2021).
    Double/Debiased Machine Learning for Dynamic Treatment Effects.
    URL https://arxiv.org/abs/2002.07285

.. [Hernan2010]
    Hernán, Miguel A., and James M. Robins (2010).
    Causal inference.
    URL https://www.hsph.harvard.edu/miguel-hernan/causal-inference-book/
1 change: 1 addition & 0 deletions doc/spec/spec.rst
@@ -19,6 +19,7 @@ The EconML Python SDK, developed by the ALICE team at MSR New England, incorpora
comparison
estimation
estimation_iv
estimation_dynamic
inference
interpretability
references
Expand Down
2 changes: 1 addition & 1 deletion econml/_cate_estimator.py
@@ -563,7 +563,7 @@ def effect(self, X=None, *, T0, T1):
"""
Calculate the heterogeneous treatment effect :math:`\\tau(X, T0, T1)`.

The effect is calculatred between the two treatment points
The effect is calculated between the two treatment points
conditional on a vector of features on a set of m test samples :math:`\\{T0_i, T1_i, X_i\\}`.
Since this class assumes a linear effect, only the difference between T0ᵢ and T1ᵢ
matters for this computation.
21 changes: 13 additions & 8 deletions econml/_ortho_learner.py
@@ -685,7 +685,8 @@ def fit(self, Y, T, X=None, W=None, Z=None, *, sample_weight=None, freq_weight=N
nuisances=nuisances,
sample_weight=sample_weight,
freq_weight=freq_weight,
sample_var=sample_var)
sample_var=sample_var,
groups=groups)

return self

@@ -770,18 +771,19 @@ def _fit_nuisances(self, Y, T, X=None, W=None, Z=None, sample_weight=None, group
return nuisances, fitted_models, fitted_inds, scores

def _fit_final(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None,
freq_weight=None, sample_var=None):
freq_weight=None, sample_var=None, groups=None):
self._ortho_learner_model_final.fit(Y, T, **filter_none_kwargs(X=X, W=W, Z=Z,
nuisances=nuisances,
sample_weight=sample_weight,
freq_weight=freq_weight,
sample_var=sample_var))
sample_var=sample_var,
groups=groups))
self.score_ = None
if hasattr(self._ortho_learner_model_final, 'score'):
self.score_ = self._ortho_learner_model_final.score(Y, T, **filter_none_kwargs(X=X, W=W, Z=Z,
nuisances=nuisances,
sample_weight=sample_weight)
)
sample_weight=sample_weight,
groups=groups))

def const_marginal_effect(self, X=None):
X, = check_input_arrays(X)
@@ -816,7 +818,7 @@ def effect_inference(self, X=None, *, T0=0, T1=1):
return super().effect_inference(X, T0=T0, T1=T1)
effect_inference.__doc__ = LinearCateEstimator.effect_inference.__doc__

def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
"""
Score the fitted CATE model on a new data set. Generates nuisance parameters
for the new data set based on the fitted nuisance models created at fit time.
@@ -840,6 +842,8 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
Instruments for each sample
sample_weight: optional(n,) vector or None (Default=None)
Weights for each sample
groups: (n,) vector, optional
All rows corresponding to the same group will be kept together during splitting.

Returns
-------
@@ -862,7 +866,7 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
for i, models_nuisances in enumerate(self._models_nuisance):
# for each model under cross fit setting
for j, mdl in enumerate(models_nuisances):
nuisance_temp = mdl.predict(Y, T, **filter_none_kwargs(X=X, W=W, Z=Z))
nuisance_temp = mdl.predict(Y, T, **filter_none_kwargs(X=X, W=W, Z=Z, groups=groups))
if not isinstance(nuisance_temp, tuple):
nuisance_temp = (nuisance_temp,)

@@ -876,7 +880,8 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None):
nuisances[it] = np.mean(nuisances[it], axis=0)

return self._ortho_learner_model_final.score(Y, T, nuisances=nuisances,
**filter_none_kwargs(X=X, W=W, Z=Z, sample_weight=sample_weight))
**filter_none_kwargs(X=X, W=W, Z=Z,
sample_weight=sample_weight, groups=groups))

@property
def ortho_learner_model_final_(self):
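For readers unfamiliar with the `filter_none_kwargs` helper that this diff threads `groups` through: it forwards only the keyword arguments that are not `None`, which is why passing `groups=None` stays backward compatible with nuisance models whose signatures lack a `groups` parameter. A minimal equivalent (a sketch of the behavior, not the library's exact implementation):

```python
def filter_none_kwargs(**kwargs):
    """Return only the keyword arguments whose value is not None."""
    return {k: v for k, v in kwargs.items() if v is not None}

# With groups=None, the key is simply omitted from the forwarded kwargs,
# so downstream calls that do not accept `groups` are unaffected.
forwarded = filter_none_kwargs(X=[[1.0]], W=None, groups=None)
```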