Skip to content

Commit

Permalink
Switch causal analysis policy tree output format
Browse files Browse the repository at this point in the history
  • Loading branch information
kbattocchi committed Jun 26, 2021
1 parent b1a7f44 commit b4d5716
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 26 deletions.
34 changes: 29 additions & 5 deletions econml/solutions/causal_analysis/_causal_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,29 @@ def recurse(node_id):
return recurse(0)


class _PolicyOutput:
    """
    A type encapsulating various information related to a learned policy.

    Attributes
    ----------
    tree_dictionary : dict
        The policy tree represented as a dictionary.
    policy_value : float
        The average value of applying the recommended policy (over using the control).
    always_treat : dict of string to float
        A dictionary mapping each non-control treatment to the value of always treating with it (over control).
    control_name : string
        The name of the control treatment.
    """

    def __init__(self, tree_dictionary, policy_value, always_treat, control_name):
        self.tree_dictionary = tree_dictionary
        self.policy_value = policy_value
        self.always_treat = always_treat
        self.control_name = control_name

    def __repr__(self):
        # Spell out every field so that logs and failing assertions show the
        # learned policy rather than an opaque object address.
        return ("{}(tree_dictionary={!r}, policy_value={!r}, "
                "always_treat={!r}, control_name={!r})").format(type(self).__name__,
                                                                self.tree_dictionary,
                                                                self.policy_value,
                                                                self.always_treat,
                                                                self.control_name)


# named tuple type for storing results inside CausalAnalysis class;
# must be lifted to module level to enable pickling
_result = namedtuple("_result", field_names=[
Expand Down Expand Up @@ -1291,10 +1314,7 @@ def _policy_tree_output(self, Xtest, feature_index, *, treatment_costs=0,
Returns
-------
tree : tuple of string, float, list of float
The policy tree represented as a graphviz string,
the value of applying the recommended policy (over never treating),
the value of always treating (over never treating) for each non-control treatment
output : _PolicyOutput
"""

(intrp, feature_names, treatment_names,
Expand All @@ -1307,7 +1327,11 @@ def _policy_tree_output(self, Xtest, feature_index, *, treatment_costs=0,

def policy_data(tree, node_id):
return {'treatment': treatment_names[np.argmax(tree.value[node_id])]}
return _tree_interpreter_to_dict(intrp, feature_names, policy_data), policy_val, always_trt.tolist()
return _PolicyOutput(_tree_interpreter_to_dict(intrp, feature_names, policy_data),
policy_val,
{treatment_names[i + 1]: val
for (i, val) in enumerate(always_trt.tolist())},
treatment_names[0])

# TODO: it seems like it would be better to just return the tree itself rather than plot it;
# however, the tree can't store the feature and treatment names we compute here...
Expand Down
67 changes: 46 additions & 21 deletions econml/tests/test_causal_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,18 @@ def test_basic_array(self):
# Make sure we handle continuous, binary, and multi-class treatments
# For multiple discrete treatments, one "always treat" value per non-default treatment
for (idx, length) in [(0, 1), (1, 1), (2, 1), (3, 2)]:
_, policy_val, always_trt = ca._policy_tree_output(X, idx)
assert isinstance(always_trt, list)
pto = ca._policy_tree_output(X, idx)
policy_val = pto.policy_value
always_trt = pto.always_treat
assert isinstance(pto.control_name, str)
assert isinstance(always_trt, dict)
assert np.array(policy_val).shape == ()
assert np.array(always_trt).shape == (length,)
assert len(always_trt) == length
for val in always_trt.values():
assert np.array(val).shape == ()

# policy value should exceed always treating with any treatment
assert_less_close(always_trt, policy_val)
assert_less_close(np.array(list(always_trt.values())), policy_val)

# global shape is (d_y, sum(d_t))
assert glo_point_est.shape == coh_point_est.shape == (1, 5)
Expand Down Expand Up @@ -149,13 +154,18 @@ def test_basic_pandas(self):
# Make sure we handle continuous, binary, and multi-class treatments
# For multiple discrete treatments, one "always treat" value per non-default treatment
for (idx, length) in [(0, 1), (1, 1), (2, 1), (3, 2)]:
_, policy_val, always_trt = ca._policy_tree_output(X, inds[idx])
assert isinstance(always_trt, list)
pto = ca._policy_tree_output(X, inds[idx])
policy_val = pto.policy_value
always_trt = pto.always_treat
assert isinstance(pto.control_name, str)
assert isinstance(always_trt, dict)
assert np.array(policy_val).shape == ()
assert np.array(always_trt).shape == (length,)
assert len(always_trt) == length
for val in always_trt.values():
assert np.array(val).shape == ()

# policy value should exceed always treating with any treatment
assert_less_close(always_trt, policy_val)
assert_less_close(np.array(list(always_trt.values())), policy_val)

if not classification:
# ExitStack can be used as a "do nothing" ContextManager
Expand Down Expand Up @@ -220,13 +230,18 @@ def test_automl_first_stage(self):
# Make sure we handle continuous, binary, and multi-class treatments
# For multiple discrete treatments, one "always treat" value per non-default treatment
for (idx, length) in [(0, 1), (1, 1), (2, 1), (3, 2)]:
_, policy_val, always_trt = ca._policy_tree_output(X, idx)
assert isinstance(always_trt, list)
pto = ca._policy_tree_output(X, idx)
policy_val = pto.policy_value
always_trt = pto.always_treat
assert isinstance(pto.control_name, str)
assert isinstance(always_trt, dict)
assert np.array(policy_val).shape == ()
assert np.array(always_trt).shape == (length,)
assert len(always_trt) == length
for val in always_trt.values():
assert np.array(val).shape == ()

# policy value should exceed always treating with any treatment
assert_less_close(always_trt, policy_val)
assert_less_close(np.array(list(always_trt.values())), policy_val)

# global shape is (d_y, sum(d_t))
assert glo_point_est.shape == coh_point_est.shape == (1, 5)
Expand Down Expand Up @@ -328,13 +343,18 @@ def test_final_models(self):
# Make sure we handle continuous, binary, and multi-class treatments
# For multiple discrete treatments, one "always treat" value per non-default treatment
for (idx, length) in [(0, 1), (1, 1), (2, 1), (3, 2)]:
_, policy_val, always_trt = ca._policy_tree_output(X, idx)
assert isinstance(always_trt, list)
pto = ca._policy_tree_output(X, idx)
policy_val = pto.policy_value
always_trt = pto.always_treat
assert isinstance(pto.control_name, str)
assert isinstance(always_trt, dict)
assert np.array(policy_val).shape == ()
assert np.array(always_trt).shape == (length,)
assert len(always_trt) == length
for val in always_trt.values():
assert np.array(val).shape == ()

# policy value should exceed always treating with any treatment
assert_less_close(always_trt, policy_val)
assert_less_close(np.array(list(always_trt.values())), policy_val)

if not classification:
# ExitStack can be used as a "do nothing" ContextManager
Expand Down Expand Up @@ -400,13 +420,18 @@ def test_forest_with_pandas(self):
# Make sure we handle continuous, binary, and multi-class treatments
# For multiple discrete treatments, one "always treat" value per non-default treatment
for (idx, length) in [(0, 1), (1, 1), (2, 1), (3, 2)]:
_, policy_val, always_trt = ca._policy_tree_output(X, inds[idx])
assert isinstance(always_trt, list)
pto = ca._policy_tree_output(X, inds[idx])
policy_val = pto.policy_value
always_trt = pto.always_treat
assert isinstance(pto.control_name, str)
assert isinstance(always_trt, dict)
assert np.array(policy_val).shape == ()
assert np.array(always_trt).shape == (length,)
assert len(always_trt) == length
for val in always_trt.values():
assert np.array(val).shape == ()

# policy value should exceed always treating with any treatment
assert_less_close(always_trt, policy_val)
assert_less_close(np.array(list(always_trt.values())), policy_val)

def test_warm_start(self):
for classification in [True, False]:
Expand Down Expand Up @@ -455,7 +480,7 @@ def test_empty_hinds(self):
eff = ca.global_causal_effect(alpha=0.05)
eff = ca.local_causal_effect(X_df, alpha=0.05)
for ind in feat_inds:
tree, val, always_trt = ca._policy_tree_output(X_df, ind)
pto = ca._policy_tree_output(X_df, ind)

def test_can_serialize(self):
import pickle
Expand Down

0 comments on commit b4d5716

Please sign in to comment.