Skip to content

Commit

Permalink
Merge pull request #418 from martin-rabel/mediation_update
Browse files Browse the repository at this point in the history
Mediation update
  • Loading branch information
jakobrunge authored Jul 17, 2024
2 parents 9039f95 + ab537ab commit 8a86147
Show file tree
Hide file tree
Showing 4 changed files with 1,250 additions and 104 deletions.
190 changes: 181 additions & 9 deletions tests/test_causal_mediation.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def test_graph_nde_fully_binary():
fit_setup = mediation.FitSetup()

tau_max = 2
estimator = mediation.NaturalEffects_GraphMediation(graph, graph_type, tau_max, fit_setup, world.Observables(),
estimator = mediation.NaturalEffects_GraphMediation(graph, graph_type, tau_max, fit_setup, obs,
effect_source=X, effect_target=M,
blocked_mediators="all", adjustment_set="auto",
fall_back_to_total_effect=True)
Expand Down Expand Up @@ -645,9 +645,9 @@ def test_interventions(med_vars, env, world, world0, world1):
# in the "real" world A is noise, thus not always 0
assert np.any(world.Observables()[med_vars.A] != 0.0)
# After intervention A=0, a should always be 0
assert np.alltrue(world0.Observables()[med_vars.A] == 0.0)
assert np.all(world0.Observables()[med_vars.A] == 0.0)
# After intervention A=1, a should always be 1
assert np.alltrue(world1.Observables()[med_vars.A] == 1.0)
assert np.all(world1.Observables()[med_vars.A] == 1.0)


@pytest.fixture(scope="module")
Expand All @@ -660,11 +660,11 @@ def cfworld(med_vars, env, model, world, world0, world1):


def test_cfworld(med_vars, cfworld, world, world0, world1):
    """Check the counterfactual world against the worlds it is compared to.

    Asserts that, in ``cfworld``:

    * the observable ``A`` equals the observable ``A`` of ``world1``,
    * the observable ``M`` equals the observable ``M`` of ``world0``,
    * the noise terms of ``A``, ``M`` and ``Y`` equal those of ``world``.

    Parameters
    ----------
    med_vars, cfworld, world, world0, world1
        Pytest fixtures defined elsewhere in this test module.
    """
    # np.all is used throughout: the np.alltrue alias was removed in NumPy 2.0.
    assert np.all(cfworld.Observables()[med_vars.A] == world1.Observables()[med_vars.A])
    assert np.all(cfworld.Observables()[med_vars.M] == world0.Observables()[med_vars.M])
    assert np.all(cfworld.data[med_vars.A.Noise()] == world.data[med_vars.A.Noise()])
    assert np.all(cfworld.data[med_vars.M.Noise()] == world.data[med_vars.M.Noise()])
    assert np.all(cfworld.data[med_vars.Y.Noise()] == world.data[med_vars.Y.Noise()])



Expand All @@ -674,9 +674,71 @@ def test_cfworld(med_vars, cfworld, world, world0, world1):
-------------------------------------------------------------------------------------------"""


def test_tutorial_example0():
    """Estimate a natural direct effect on linear toy data, raw and normalized.

    A three-variable linear process is simulated in which variable 2 depends on
    variable 1 at lag 2 with coefficient ``direct_eff``. The natural direct effect
    of X = (1, -2) on Y = (2, 0) is fitted once on the raw data and once on data
    where each column is divided by its standard deviation. The estimate from
    the normalized data is rescaled by std(column 2) / std(column 1) and must
    lie strictly between 0.3 and 0.6.
    """
    from tigramite import data_processing as pp
    from tigramite.toymodels import structural_causal_processes as toys

    graph = np.array(
        [
            [['', '-->', ''],
             ['', '', ''],
             ['', '', '']],
            [['', '-->', ''],
             ['', '-->', ''],
             ['-->', '', '-->']],
            [['', '', ''],
             ['<--', '', ''],
             ['', '-->', '']],
        ],
        dtype='<U3',
    )

    X = [(1, -2)]
    Y = [(2, 0)]
    causal_effects = CausalMediation(
        graph,
        graph_type='stationary_dag',
        X=X,
        Y=Y,
        S=None,  # (currently S must be None)
        hidden_variables=None,  # (currently hidden must be None)
        verbosity=1,
    )
    var_names = ['$X^0$', '$X^1$', '$X^2$']

    # Result is not inspected; the call is kept so the test still exercises it.
    opt = causal_effects.get_optimal_set()

    coeff = .5
    direct_eff = 0.5

    def lin_f(x):
        return x

    links_coeffs = {
        0: [((0, -1), coeff, lin_f), ((1, -1), coeff, lin_f)],
        1: [((1, -1), coeff, lin_f)],
        2: [((2, -1), coeff, lin_f), ((1, 0), coeff, lin_f), ((1, -2), direct_eff, lin_f)],
    }

    # Observational data
    T = 1000
    data, nonstat = toys.structural_causal_process(links_coeffs, T=T, noises=None, seed=42)

    # Per-column standard deviations, and a copy of the data scaled by them.
    normalization = [np.std(data[:, col]) for col in range(3)]
    data_normalized = np.empty_like(data)
    for col, scale in enumerate(normalization):
        data_normalized[:, col] = data[:, col] / scale

    dataframe = pp.DataFrame(data, var_names=var_names)
    dataframe_normalized = pp.DataFrame(data_normalized, var_names=var_names)
    fit_setup = mediation.FitSetup(mediation.FitProvider_Continuous_Default.UseSklearn(20))

    # Fit and predict on the unnormalized data.
    estimator_raw = causal_effects.fit_natural_direct_effect(
        dataframe, blocked_mediators='all', mixed_data_estimator=fit_setup
    )
    estimator_raw.PrintInfo()
    nde_est = causal_effects.predict_natural_direct_effect(0.0, 1.0)

    # Fit and predict on the normalized data, then rescale the estimate.
    causal_effects.fit_natural_direct_effect(
        dataframe_normalized, blocked_mediators='all', mixed_data_estimator=fit_setup
    )
    rescale = normalization[2] / normalization[1]
    nde_est_from_normalized = (
        causal_effects.predict_natural_direct_effect(0.0, 1.0) * normalization[2] / normalization[1]
    )

    # print results
    print(
        f"Estimate of the NDE is:\n{nde_est} from unnormalized data, "
        f"\n{nde_est_from_normalized} from normalized data,\nground-truth is {direct_eff}."
    )

    assert 0.3 < nde_est_from_normalized < 0.6


def test_tutorial_example():
def test_tutorial_example():
graph = np.array([[['', '-->', ''],
['', '', ''],
['', '', '']],
Expand Down Expand Up @@ -728,3 +790,113 @@ def lin_f(x): return x
print(f"Estimate of the NDE is {nde_est}, ground-truth is {direct_eff}.")

assert 0.8 < nde_est / direct_eff < 1.1




class FitProvider_Continous_Filtered:
    """Fit provider for continuous maps that filters its training data first.

    Wraps another fit provider: training inputs and targets are passed through
    ``filter_to_use.apply`` and the resulting pair is forwarded to the wrapped
    provider's ``Get_Fit_Continuous``.

    Parameters
    ----------
    underlying_fit_provider : object
        Provider exposing ``Get_Fit_Continuous(x_train, y_train)``.
    filter_to_use : object
        Filter exposing ``apply(x, y)`` that returns the filtered ``(x, y)`` pair.
    """

    def __init__(self, underlying_fit_provider, filter_to_use):
        self.underlying_fit_provider = underlying_fit_provider
        self.filter = filter_to_use

    def Get_Fit_Continuous(self, x_train, y_train):
        """Return the wrapped provider's fit on the filtered training data."""
        filtered_pair = self.filter.apply(x_train, y_train)
        return self.underlying_fit_provider.Get_Fit_Continuous(*filtered_pair)

class FitProvider_Density_Filtered:
    """Density fit provider that filters its training data first.

    Wraps another fit provider: the training samples are passed through
    ``filter_to_use.apply`` and the result is forwarded to the wrapped
    provider's ``Get_Fit_Density``.

    Parameters
    ----------
    underlying_fit_provider : object
        Provider exposing ``Get_Fit_Density(x_train)``.
    filter_to_use : object
        Filter exposing ``apply(x)`` that returns the filtered samples.
    """

    def __init__(self, underlying_fit_provider, filter_to_use):
        self.underlying_fit_provider = underlying_fit_provider
        self.filter = filter_to_use

    def Get_Fit_Density(self, x_train):
        """Return the wrapped provider's density fit on the filtered samples."""
        filtered_samples = self.filter.apply(x_train)
        return self.underlying_fit_provider.Get_Fit_Density(filtered_samples)

class FilterMissingValues:
    """Row filter that drops samples containing a missing-value marker.

    Parameters
    ----------
    missing_value_flag : scalar
        Value that marks a missing entry. A NaN flag is supported: NaN entries
        are then detected with ``np.isnan`` because ``NaN == NaN`` is False.
    """

    def __init__(self, missing_value_flag):
        self.missing_value_flag = missing_value_flag

    def _is_missing(self, values):
        """Return a boolean mask marking the entries of ``values`` equal to the flag."""
        flag = self.missing_value_flag
        # An equality test can never match NaN, so a NaN flag needs np.isnan.
        if isinstance(flag, (float, np.floating)) and np.isnan(flag):
            return np.isnan(values)
        return values == flag

    def apply(self, x, y=None):
        """Remove the rows that contain a missing value.

        Parameters
        ----------
        x : numpy array, indexed as (sample, feature)
            A row is dropped if any of its entries is flagged as missing.
        y : numpy array or None
            Optional per-sample values indexed like the rows of ``x``; a row is
            also dropped if its ``y`` entry is flagged as missing.

        Returns
        -------
        numpy array, or tuple of two numpy arrays
            The filtered ``x`` if ``y`` is None, otherwise ``(x, y)`` filtered
            with the same row mask.
        """
        missing_in_any_x = np.any(self._is_missing(x), axis=1)
        if y is None:
            return x[np.logical_not(missing_in_any_x)]
        missing_in_y = self._is_missing(y)
        valid = np.logical_not(np.logical_or(missing_in_any_x, missing_in_y))
        return x[valid], y[valid]

def apply_filter_to_all_inputs(fit_setup, filter_to_apply):
    """Return a new fit setup whose fit providers filter their training data.

    The map-fit and density-fit providers of ``fit_setup`` are each wrapped so
    that ``filter_to_apply`` is applied to the training data before fitting. A
    new object of the same class as ``fit_setup`` is built from the wrappers.

    Parameters
    ----------
    fit_setup : object
        Must expose ``fit_map`` and ``fit_density`` and have a class that can be
        constructed from the keyword arguments ``fit_map`` and ``fit_density``.
    filter_to_apply : object
        Filter handed to both wrappers.

    Returns
    -------
    object
        New instance of ``fit_setup``'s class using the filtered providers.
    """
    # Assumes the fit_setup can be constructed from map & density fit and has
    # the corresponding members (for all implementations based on the FitSetup
    # class in the mediation-module of tigramite this is the case; see tutorial
    # on mediation, appendix B).
    filtered_map = FitProvider_Continous_Filtered(fit_setup.fit_map, filter_to_apply)
    filtered_density = FitProvider_Density_Filtered(fit_setup.fit_density, filter_to_apply)
    setup_class = fit_setup.__class__
    return setup_class(fit_map=filtered_map, fit_density=filtered_density)


def test_tutorial_example_custom_fit():
    """Estimate a natural direct effect with missing values filtered at fit time.

    Linear toy data is simulated and each column is divided by its standard
    deviation. Random gaps are then overwritten with the marker value 999 in a
    copy of the normalized data. The marker is deliberately not declared to the
    DataFrame; instead the fit providers are wrapped with
    ``FilterMissingValues(999)`` and dataframe-based preprocessing is disabled,
    so the marked samples are removed just before each fit.
    """
    from tigramite import data_processing as pp
    from tigramite.toymodels import structural_causal_processes as toys

    graph = np.array(
        [
            [['', '-->', ''],
             ['', '', ''],
             ['', '', '']],
            [['', '-->', ''],
             ['', '-->', ''],
             ['-->', '', '-->']],
            [['', '', ''],
             ['<--', '', ''],
             ['', '-->', '']],
        ],
        dtype='<U3',
    )

    X = [(1, -2)]
    Y = [(2, 0)]
    var_names = ['$X^0$', '$X^1$', '$X^2$']

    coeff = .5
    direct_eff = 0.5

    def lin_f(x):
        return x

    links_coeffs = {
        0: [((0, -1), coeff, lin_f), ((1, -1), coeff, lin_f)],
        1: [((1, -1), coeff, lin_f)],
        2: [((2, -1), coeff, lin_f), ((1, 0), coeff, lin_f), ((1, -2), direct_eff, lin_f)],
    }

    # Observational data; fixed seed so that the test is reproducible.
    T = 1000
    data, nonstat = toys.structural_causal_process(links_coeffs, T=T, noises=None, seed=42)

    # Per-column standard deviations, and the data scaled by them.
    normalization = [np.std(data[:, v]) for v in range(3)]
    data_normalized = np.empty_like(data)
    for v, scale in enumerate(normalization):
        data_normalized[:, v] = data[:, v] / scale

    # Draw random gaps: one variable, start offset and length per gap.
    seed = 12345
    missing_marker = 999
    gap_count = 10
    gap_min_len = 10
    gap_max_len = 20
    rng = np.random.default_rng(seed)
    var_idx = rng.integers(0, data_normalized.shape[1], gap_count)
    offset = rng.integers(0, data_normalized.shape[0] - gap_max_len, gap_count)
    missing_count = rng.integers(gap_min_len, gap_max_len, gap_count)

    # Work on a copy so the normalized data itself stays untouched.
    modified_data = data_normalized.copy()
    for gap in range(gap_count):
        start = offset[gap]
        # Each gap affects only its own variable (var_idx[gap], not all of var_idx).
        modified_data[start:start + missing_count[gap], var_idx[gap]] = missing_marker

    fit_setup = mediation.FitSetup(mediation.FitProvider_Continuous_Default.UseSklearn(20))
    fit_setup2 = apply_filter_to_all_inputs(fit_setup, FilterMissingValues(missing_marker))
    # No missing_flag is passed: the markers are handled by the filtered fit setup.
    dataframe_unmarked_missing = pp.DataFrame(modified_data, var_names=var_names)

    causal_effects = CausalMediation(
        graph,
        graph_type='stationary_dag',
        X=X,
        Y=Y,
        S=None,  # (currently S must be None)
        hidden_variables=None,  # (currently hidden must be None)
        verbosity=1,
    )
    # normalized data
    causal_effects.fit_natural_direct_effect(
        dataframe_unmarked_missing,
        blocked_mediators='all',
        mixed_data_estimator=fit_setup2,  # set the new fit_setup
        enable_dataframe_based_preprocessing=False,
    )

    # Rescale the estimate from normalized data by std(column 2) / std(column 1).
    nde_est = causal_effects.predict_natural_direct_effect(0.0, 1.0) * normalization[2] / normalization[1]

    # print results
    print(f"Estimate of the NDE is:\n{nde_est} with missing values,\nground-truth is {direct_eff}.")

    # Minimal sanity check: the estimate must be a finite number.
    assert np.all(np.isfinite(nde_est))
46 changes: 32 additions & 14 deletions tigramite/causal_mediation.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ def _Fit_Enriched_TransferMatrix(self, x_continuous,
for cY in range(y_category_count):
# Use Bayes' Theorem to avoid fitting densities conditional on continuous variables
filter_both = np.logical_and(x_categorical == cX, y_categorical == cY)
assert len(filter_both.shape) == 1
p_y_given_discrete_x = np.count_nonzero(filter_both) / normalization
p_continuous_x_given_discrete_x_and_y = self.fit_density.Get_Fit_Density(
x_continuous[:, filter_both].T)
Expand Down Expand Up @@ -1050,21 +1051,25 @@ class NaturalEffects_GraphMediation:
however, for comparison to other estimates, using this option might yield more consistent results.
_internal_provide_cfx : *None* or tigramite.CausalEffects
Set to None. Used when called from CausalMediation, which already has a causal-effects class.
enable_dataframe_based_preprocessing : bool
Enable (and enforce) data-preprocessing through the tigramite::dataframe, makes missing-data
and other features available to the mediation analysis. Custom (just in time) handling
of missing data might be more sample-efficient.
"""

def __init__(self, graph, graph_type, tau_max, fit_setup, observations_data, effect_source, effect_target,
blocked_mediators="all", adjustment_set="auto", only_check_validity=False,
fall_back_to_total_effect=False, _internal_provide_cfx=None):
fall_back_to_total_effect=False, _internal_provide_cfx=None, enable_dataframe_based_preprocessing=True):

data = toy_setup.DataHandler(observations_data)
data = toy_setup.DataHandler(observations_data, dataframe_based_preprocessing=enable_dataframe_based_preprocessing)

self.Source = data.GetVariableAuto(effect_source)
self.Target = data.GetVariableAuto(effect_target)
self.Source = data.GetVariableAuto(effect_source, "Source")
self.Target = data.GetVariableAuto(effect_target, "Target")

if blocked_mediators != "all":
blocked_mediators = data.GetVariablesAuto(blocked_mediators)
blocked_mediators = data.GetVariablesAuto(blocked_mediators, "Mediator")
if adjustment_set != "auto":
adjustment_set = data.GetVariablesAuto(adjustment_set)
adjustment_set = data.GetVariablesAuto(adjustment_set, "Adjustment")

X = data[self.Source]
Y = data[self.Target]
Expand All @@ -1078,7 +1083,7 @@ def __init__(self, graph, graph_type, tau_max, fit_setup, observations_data, eff
all_mediators = blocked_mediators == "all"
if all_mediators:
M = cfx_xy.M
blocked_mediators = data.ReverseLookupMulti(M)
blocked_mediators = data.ReverseLookupMulti(M, "Mediator")
else:
M = data[blocked_mediators]
if not set(M) <= set(cfx_xy.M):
Expand Down Expand Up @@ -1107,7 +1112,7 @@ def valid(S):
# fall back to adjust, which should work if any single adjustmentset works
Z = cfx_xy._get_adjust_set()

adjustment_set = data.ReverseLookupMulti(Z)
adjustment_set = data.ReverseLookupMulti(Z, "Adjustment")

else:
Z = data[adjustment_set]
Expand Down Expand Up @@ -1135,6 +1140,11 @@ def valid(S):
"X -> Y. If such a set exists, Perkovic's Adjust(X,Y) is valid, which is tried as "
"fallback if adjustment-set='auto' is used.")


# lock in mediators and adjustment for preprocessing
self.BlockedMediators = data.ReverseLookupMulti(M)
self.AdjustmentSet = data.ReverseLookupMulti(Z)

# ----- STORE RESULTS ON INSTANCE -----
self.X, self.sources = data.Get("Source", [X], tau_max=tau_max)
self.X = self.X[0] # currently univariate anyway
Expand All @@ -1146,15 +1156,13 @@ def valid(S):
self.M_ids = None
self.mediators = {}

self.BlockedMediators = data.ReverseLookupMulti(M)

if len(Z) > 0:
self.Z_ids, self.adjustment = data.Get("Adjustment", Z, tau_max=tau_max)
else:
self.Z_ids = None
self.adjustment = {}

self.AdjustmentSet = data.ReverseLookupMulti(Z)

self.fit_setup = fit_setup
self._E_Y_XMZ = None
Expand Down Expand Up @@ -1294,7 +1302,12 @@ def _NDE_categorical_target_full_density(self, cf_x, reference_x):
# and they are purely categorical, then the mapping (M u Z) -> P_Y
# has finite image, treating it as categorical gives better results

labels_y, transformed_y = np.unique(p_y_values, return_inverse=True, axis=0)
# different numpy-versions behave differently wrt this call:
# https://numpy.org/devdocs/release/2.0.0-notes.html#np-unique-return-inverse-shape-for-multi-dimensional-inputs
# see also https://github.com/numpy/numpy/issues/26738
labels_y, transformed_y_numpy_version_dependent = np.unique(p_y_values, return_inverse=True, axis=0)
transformed_y = transformed_y_numpy_version_dependent.squeeze()

P_Y = toy_setup.CategoricalVariable(categories=labels_y)
P_P_Y_xz = self.fit_setup.Fit({**self.sources, **self.adjustment}, {P_Y: transformed_y})

Expand Down Expand Up @@ -1446,7 +1459,8 @@ def __init__(self, graph, graph_type, X, Y, S=None, hidden_variables=None, verbo

def fit_natural_direct_effect(self, dataframe, mixed_data_estimator=FitSetup(),
blocked_mediators='all', adjustment_set='auto',
use_mediation_impl_for_total_effect_fallback=False):
use_mediation_impl_for_total_effect_fallback=False,
enable_dataframe_based_preprocessing=True):
"""Fit a natural direct effect.
Parameters
Expand All @@ -1467,7 +1481,11 @@ def fit_natural_direct_effect(self, dataframe, mixed_data_estimator=FitSetup(),
use_mediation_impl_for_total_effect_fallback : bool
If True, if no mediators are blocked, use mediation implementation to estimate the total effect.
In this case, estimating the total effect through the 'Causal Effects' class might be easier,
however, for comparison to other estimates, using this option might yield more consistent results.
however, for comparison to other estimates, using this option might yield more consistent results.
enable_dataframe_based_preprocessing : bool
Enable (and enforce) data-preprocessing through the tigramite::dataframe, makes missing-data
and other features available to the mediation analysis. Custom (just in time) handling
of missing data might be more sample-efficient.
Returns
-------
Expand All @@ -1491,7 +1509,7 @@ def fit_natural_direct_effect(self, dataframe, mixed_data_estimator=FitSetup(),
effect_source=source, effect_target=target,
blocked_mediators=self.BlockedMediators, adjustment_set=adjustment_set, only_check_validity=False,
fall_back_to_total_effect=use_mediation_impl_for_total_effect_fallback,
_internal_provide_cfx=self)
_internal_provide_cfx=self, enable_dataframe_based_preprocessing=enable_dataframe_based_preprocessing)
# return a NDE_Graph Estimator, but also remember it for predict_nde
return self.MediationEstimator

Expand Down
Loading

0 comments on commit 8a86147

Please sign in to comment.