Commit: Address PR feedback

kbattocchi committed May 7, 2021
1 parent 3b3850e commit 9dc967b
Showing 2 changed files with 211 additions and 62 deletions.
104 changes: 43 additions & 61 deletions econml/solutions/causal_analysis/_causal_analysis.py
@@ -19,6 +19,7 @@
from sklearn.utils.validation import column_or_1d
from ...cate_interpreter import SingleTreeCateInterpreter, SingleTreePolicyInterpreter
from ...dml import LinearDML, CausalForestDML
from ...inference import NormalInferenceResults
from ...sklearn_extensions.linear_model import WeightedLasso
from ...sklearn_extensions.model_selection import GridSearchCVList
from ...utilities import _RegressionWrapper, inverse_onehot
@@ -128,7 +129,7 @@ def _first_stage_clf(X, y, *, make_regressor=False, automl=True):
RandomForestClassifier(
n_estimators=100, random_state=123),
GradientBoostingClassifier(random_state=123)],
param_grid_list=[{'C': [0.01, .1, 1, 10, 100]},
param_grid_list=[{'logisticregression__C': [0.01, .1, 1, 10, 100]},
{'max_depth': [3, 5],
'min_samples_leaf': [10, 50]},
{'n_estimators': [50, 100],
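The rename from `C` to `logisticregression__C` suggests the logistic-regression candidate in `GridSearchCVList` is wrapped in a pipeline, where hyperparameters must be addressed as `<step>__<param>`. A minimal sketch of that naming convention using plain scikit-learn (the pipeline and grid here are illustrative, not the ones in this module):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = make_classification(random_state=123)

# make_pipeline names each step after its lowercased class name,
# so the LogisticRegression step is addressed as 'logisticregression'
clf = make_pipeline(StandardScaler(), LogisticRegression())

# a bare 'C' would raise an "Invalid parameter" error against the Pipeline;
# pipeline hyperparameters must use the '<step>__<param>' convention
search = GridSearchCV(clf, {'logisticregression__C': [0.01, 0.1, 1, 10, 100]})
search.fit(X, y)
print(search.best_params_)
```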
@@ -156,25 +157,6 @@ def _final_stage():
scoring='neg_mean_squared_error')


# TODO: make public and move to utilities
class _ConstantFeatures(TransformerMixin):
"""
Transformer that ignores its input, outputting a single constant column
"""

def __init__(self, constant=1):
self.constant = 1

def fit(self, _):
return self

def transform(self, arr):
n, _ = arr.shape
return np.full((n, 1), self.constant)

def get_feature_names(self, _=None):
return ["constant"]

# simplification of sklearn's ColumnTransformer that encodes categoricals and passes through selected other columns
# but also supports get_feature_names with expected signature

@@ -260,15 +242,13 @@ class CausalAnalysis:
of the form theta(X)=<a, X> will be used, while 'forest' means that a forest model will be trained instead.
TODO. Add other options, such as {'automl'} for performing
model selection for the causal effect, or {'sparse_linear'} for using a debiased lasso. (post-MVP)
automl: bool, default True
Whether to automatically perform model selection over a variety of models
n_jobs: int, default -1
Degree of parallelism to use when training models via joblib.Parallel
"""

_results = namedtuple("_results", field_names=[
"feature_index", "feature_name", "feature_baseline", "feature_levels", "hinds",
"X_transformer", "W_transformer", "estimator", "global_inference", "d_t"])
"X_transformer", "W_transformer", "estimator", "global_inference"])

def __init__(self, feature_inds, categorical, heterogeneity_inds=None, feature_names=None, classification=False,
upper_bound_on_cat_expansion=5, nuisance_models='linear', heterogeneity_model='linear', n_jobs=-1):
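
For orientation, the constructor signature above implies usage along these lines; this is an illustrative sketch only (the DataFrame and its columns are made up, and `econml.solutions.causal_analysis` is assumed to be the public import path):

```python
import numpy as np
import pandas as pd
from econml.solutions.causal_analysis import CausalAnalysis

rng = np.random.default_rng(123)
# hypothetical data: 'x1' is categorical, 'x0' and 'x2' are continuous
df = pd.DataFrame({'x0': rng.normal(size=500),
                   'x1': rng.choice(['a', 'b'], size=500),
                   'x2': rng.normal(size=500)})
y = df['x0'] + rng.normal(size=500)

ca = CausalAnalysis(feature_inds=['x0', 'x1'], categorical=['x1'],
                    nuisance_models='linear', heterogeneity_model='linear')
ca.fit(df, y)
print(ca.global_causal_effect(alpha=0.05))
```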
@@ -342,8 +322,10 @@ def fit(self, X, y, warm_start=False):
if warm_start and self.nuisance_models != self.nuisance_models_:
warnings.warn("warm_start will be ignored since the nuisance models have changed"
f" from {self.nuisance_models_} to {self.nuisance_models} since the previous call to fit")
new_inds = train_inds
elif warm_start:
warm_start = False

# BUG: need to also train new model_y
if warm_start:
new_inds = [ind for ind in train_inds if (ind not in self._cache or
heterogeneity_inds[ind] != self._cache[ind][1].hinds)]
else:
@@ -354,12 +336,11 @@
# train the Y model

# perform model selection for the Y model using all X, not on a per-column basis
self._x_transform = ColumnTransformer([('encode',
OneHotEncoder(
drop='first', sparse=False),
self.categorical)],
remainder='passthrough')
allX = self._x_transform.fit_transform(X)
allX = ColumnTransformer([('encode',
OneHotEncoder(
drop='first', sparse=False),
self.categorical)],
remainder='passthrough').fit_transform(X)

if self.classification:
self._model_y = _first_stage_clf(
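
Inlining the transformer (rather than storing it on `self._x_transform`) makes clear that this encoding is used only once, to build the design matrix for Y-model selection. A standalone illustration of the `ColumnTransformer` pattern used here (toy data; newer scikit-learn renames `sparse=False` to `sparse_output=False`):

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder

X = pd.DataFrame({'cat': ['a', 'b', 'a'], 'num': [1., 2., 3.]})

# one-hot encode 'cat', dropping the first level to avoid collinearity;
# all other columns pass through unchanged, after the encoded columns
ct = ColumnTransformer([('encode', OneHotEncoder(drop='first', sparse=False), ['cat'])],
                       remainder='passthrough')
ct.fit_transform(X)
# array([[0., 1.],
#        [1., 2.],
#        [0., 3.]])
```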
@@ -420,13 +401,8 @@ def process_feature(name, feat_ind):
model_t = (_first_stage_clf(WX, T, automl=self.nuisance_models == 'automl')
if discrete_treatment else _first_stage_reg(WX, T, automl=self.nuisance_models == 'automl'))

# For the global model, use a constant featurizer to fit the ATE
# Ideally, this could be PolynomialFeatures(degree=0, include_bias=True), but degree=0 is unsupported
# So instead we'll use our own class
featurizer = _ConstantFeatures()
featurizer = PolynomialFeatures(degree=1, include_bias=True)

# TODO: support other types of heterogeneity via an initializer arg
# e.g. 'forest' -> ForestDML
if self.heterogeneity_model == 'linear':
est = LinearDML(model_y=self._model_y,
model_t=model_t,
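
With `_ConstantFeatures` gone, a single `PolynomialFeatures(degree=1, include_bias=True)` featurizer now serves both the global and local summaries (the removed comment notes that `degree=0`, which would have reproduced the constant featurizer, is unsupported). Concretely, this featurizer just prepends an intercept column:

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures

X = np.array([[2., 3.],
              [4., 5.]])
# degree=1 with include_bias=True yields the design matrix [1, x1, x2]
# of a CATE model that is linear in X with an explicit intercept
PolynomialFeatures(degree=1, include_bias=True).fit_transform(X)
# array([[1., 2., 3.],
#        [1., 4., 5.]])
```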
@@ -440,15 +416,22 @@
model_t=model_t,
featurizer=featurizer,
discrete_treatment=discrete_treatment,
n_estimators=4000,
random_state=123)
est.tune(y, T, X=X_xf, W=W)
est.fit(y, T, X=X_xf, W=W, cache_values=True)

# effect doesn't depend on W, so only pass in first row
global_inference = est.const_marginal_effect_inference(X=X_xf[0:1])
# For the local model, change the featurizer to include X
est.featurizer = PolynomialFeatures(degree=1, include_bias=True)
est.refit_final()
# Prefer ate__inference to const_marginal_ate_inference(X) because it is doubly-robust and not conservative
if self.heterogeneity_model == 'forest' and discrete_treatment:
global_inference = est.ate__inference()
else:
# convert to NormalInferenceResults for consistency
inf = est.const_marginal_ate_inference(X=X_xf)
global_inference = NormalInferenceResults(d_t=inf.d_t, d_y=inf.d_y,
pred=inf.mean_point,
pred_stderr=inf.stderr_mean,
mean_pred_stderr=None,
inf_type='ate')

# Set the dictionary values shared between local and global summaries
if discrete_treatment:
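
The conversion above repackages the averaged constant-marginal effect as a `NormalInferenceResults` so that both branches expose the same interface. For intuition, the interval such a result reports is the usual normal approximation around the point estimate (a worked sketch with made-up numbers, not econml code):

```python
from scipy.stats import norm

pred, pred_stderr, alpha = 0.8, 0.1, 0.05   # hypothetical estimate and stderr
z = norm.ppf(1 - alpha / 2)                  # ~1.96 for a 95% interval
lower, upper = pred - z * pred_stderr, pred + z * pred_stderr
# (0.604, 0.996)
```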
@@ -482,7 +465,6 @@ def process_feature(name, feat_ind):
X_transformer=X_transformer,
W_transformer=W_transformer,
estimator=est,
d_t=d_t,
global_inference=global_inference)

return insights, result
@@ -492,16 +474,18 @@ def process_feature(name, feat_ind):
feature_names = X.columns
else:
feature_names = [f"x{i}" for i in range(X.shape[1])]
else:
feature_names = self.feature_names

self.feature_names_ = feature_names

# extract subset matching new columns
feature_names = _safe_indexing(feature_names, new_inds)
# extract subset of names matching new columns
new_feat_names = _safe_indexing(feature_names, new_inds)

cache_updates = dict(zip(new_inds,
joblib.Parallel(n_jobs=self.n_jobs,
verbose=1)(joblib.delayed(process_feature)(feat_name, feat_ind)
for feat_name, feat_ind in zip(feature_names, new_inds))))
for feat_name, feat_ind in zip(new_feat_names, new_inds))))

self._cache.update(cache_updates)
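
The per-feature training fans out through `joblib.Parallel`, one delayed call per new column, with the results keyed back by index in `cache_updates`. The idiom in isolation:

```python
import joblib

def square(i):
    return i * i

# one delayed call per work item; results come back in input order
joblib.Parallel(n_jobs=2, verbose=1)(joblib.delayed(square)(i) for i in range(5))
# [0, 1, 4, 9, 16]
```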

@@ -618,7 +602,11 @@ def make_dataframe(props):
names=["sample", "outcome", "feature", "feature_value"])
for lvl in index.levels:
if len(lvl) == 1:
index = index.droplevel(lvl.name)
if not isinstance(index, pd.MultiIndex):
# can't drop only level
index = pd.Index(["value"])
else:
index = index.droplevel(lvl.name)
return pd.DataFrame(to_include, index=index)

return self._summarize(summary=make_dataframe,
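
The new guard is needed because pandas refuses to drop the only level of an index, while dropping one of several is fine. A quick demonstration of the behavior being worked around:

```python
import pandas as pd

mi = pd.MultiIndex.from_product([['y0'], ['a', 'b']],
                                names=['outcome', 'feature'])

# dropping one of two levels yields a flat Index
mi.droplevel('outcome')  # Index(['a', 'b'], dtype='object', name='feature')

# but dropping the only remaining level raises:
pd.Index(['a', 'b']).droplevel(0)
# ValueError: Cannot remove 1 levels from an index with 1 levels: ...
```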
@@ -679,7 +667,8 @@ def global_causal_effect(self, alpha=0.05):
in the serialized dict.
"""
# a global inference indicates the effect of that one feature on the outcome
return self._pandas_summary(lambda res: res.global_inference, props=self._point_props(alpha), n=1)
return self._pandas_summary(lambda res: res.global_inference, props=self._point_props(alpha),
n=1, expand_arr=True)

def _global_causal_effect_dict(self, alpha=0.05):
"""
@@ -690,7 +679,7 @@ def _global_causal_effect_dict(self, alpha=0.05):
Only for serialization purposes to upload to AzureML
"""
return self._dict_summary(lambda res: res.global_inference, props=self._point_props(alpha),
kind='global', drop_sample=True)
kind='global', drop_sample=True, expand_arr=True)

def _cohort_effect_inference(self, Xtest):
assert np.ndim(Xtest) == 2 and np.shape(Xtest)[1] == self._d_x, (
@@ -794,7 +783,7 @@ def _local_causal_effect_dict(self, Xtest, alpha=0.05):
return self._dict_summary(self._local_effect_inference(Xtest), props=self._point_props(alpha),
kind='local')

def _check_feature_index(self, X, feature_index):
def _safe_result_index(self, X, feature_index):
assert hasattr(self, "_results"), "This instance has not yet been fitted"

assert np.ndim(X) == 2 and np.shape(X)[1] == self._d_x, (
@@ -810,7 +799,7 @@ def _check_feature_index(self, X, feature_index):
(result,) = results
return result

def whatif(self, X, Xnew, feature_index, y):
def whatif(self, X, Xnew, feature_index, y, alpha=0.05):
"""
Get counterfactual predictions when feature_index is changed to Xnew from its observational counterpart.
@@ -827,6 +816,7 @@ def whatif(self, X, Xnew, feature_index, y):
the string name if the input is a dataframe
y: array-like
Observed labels or outcome of a predictive model for baseline y values
"""

assert not self.classification, "What-if analysis cannot be applied to classification tasks"
@@ -840,7 +830,7 @@ def _whatif_dict(self, X, Xnew, feature_index, y):

T0 = _safe_indexing(X, feature_index, axis=1)
T1 = Xnew
result = self._check_feature_index(X, feature_index)
result = self._safe_result_index(X, feature_index)
inf = result.estimator.effect_inference(
X=result.X_transformer.transform(X), T0=T0, T1=T1)
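
Here the observed value of the selected column plays the role of the baseline treatment `T0` and `Xnew` the counterfactual `T1`; the counterfactual prediction is then the observed baseline plus the estimated effect of that move. A sketch of the arithmetic under that assumption, with `est` standing in for `result.estimator` and `X_transformed` for the transformed features:

```python
# effect of moving the feature from its observed value T0 to T1 = Xnew,
# added to the observed outcome y to obtain the counterfactual prediction
y_counterfactual = y + est.effect(X=X_transformed, T0=T0, T1=T1)
```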

@@ -882,7 +872,7 @@ def _whatif_dict(self, X, Xnew, feature_index, y):
def _tree(self, is_policy, Xtest, feature_index, *, treatment_cost=0,
max_depth=3, min_samples_leaf=2, min_impurity_decrease=1e-4, alpha=.1):

result = self._check_feature_index(Xtest, feature_index)
result = self._safe_result_index(Xtest, feature_index)
Xtest = result.X_transformer.transform(Xtest)

if result.feature_baseline is None:
@@ -1046,11 +1036,3 @@ def _heterogeneity_tree_string(self, Xtest, feature_index, *,
alpha=alpha)
return intrp.export_graphviz(feature_names=feature_names,
treatment_names=treatment_names)

@property
def cate_models_(self):
return [result.estimator for result in self._results]

@property
def X_transformers_(self):
return [result.X_transformer for result in self._results]
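
The tree exports above go through econml's interpreters; for reference, the standalone pattern looks roughly like this (`est`, `X_test`, and `feature_names` are placeholders for a fitted CATE estimator and its data):

```python
from econml.cate_interpreter import SingleTreeCateInterpreter

intrp = SingleTreeCateInterpreter(max_depth=3, min_samples_leaf=2)
intrp.interpret(est, X_test)   # fit a shallow tree to the estimator's CATE predictions
dot = intrp.export_graphviz(feature_names=feature_names)  # graphviz source string
```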