Skip to content

Commit

Permalink
Merge pull request #188 from jakobrunge/developer
Browse files Browse the repository at this point in the history
Developer
  • Loading branch information
jakobrunge authored Mar 16, 2022
2 parents fcbd87c + c8de437 commit 13d0372
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 45 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def run(self):
# Run the setup
setup(
name="tigramite",
version="5.0.0.7",
version="5.0.0.8",
packages=["tigramite", "tigramite.independence_tests", "tigramite.toymodels"],
license="GNU General Public License v3.0",
description="Tigramite causal discovery for time series",
Expand Down
127 changes: 83 additions & 44 deletions tigramite/causal_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from copy import deepcopy
from collections import defaultdict
from tigramite.models import Models
import struct

class CausalEffects():
r"""Causal effect estimation.
Expand Down Expand Up @@ -132,10 +133,12 @@ def __init__(self,
# print(self.graph.shape)
self._check_graph(self.graph)

anc_Y = self._get_ancestors(Y)
self.ancX = self._get_ancestors(X)
self.ancY = self._get_ancestors(Y)
self.ancS = self._get_ancestors(S)

# If X is not in anc(Y), then no causal link exists
if anc_Y.intersection(set(X)) == set():
if self.ancY.intersection(set(X)) == set():
self.no_causal_path = True
if self.verbosity > 0:
print("No causal path from X to Y exists.")
Expand Down Expand Up @@ -169,21 +172,24 @@ def __init__(self,
if check_SM_overlap and len(self.S.intersection(self.M)) > 0:
raise ValueError("Conditions S overlap with mediators M!")

descendants = self._get_descendants(self.Y.union(self.M))

self.desX = self._get_descendants(self.X)
self.desY = self._get_descendants(self.Y)
self.desM = self._get_descendants(self.M)
self.descendants = self.desY.union(self.desM)

# Define forb as X and descendants of YM
self.forbidden_nodes = descendants.union(self.X) #.union(S)
self.forbidden_nodes = self.descendants.union(self.X) #.union(S)

# Define valid ancestors
self.vancs = self._get_ancestors(list(self.X.union(self.Y).union(self.S))) - self.forbidden_nodes
self.vancs = self.ancX.union(self.ancY).union(self.ancS) - self.forbidden_nodes

if len(self.S.intersection(self._get_descendants(self.X))) > 0:
if self.verbosity > 0:
if self.verbosity > 0:
if len(self.S.intersection(self.desX)) > 0:
print("Warning: Potentially outside assumptions: Conditions S overlap with des(X)")

# Here only check if S overlaps with des(Y), leave the option that S
# contains variables in des(M) to the user
if len(self.S.intersection(self._get_descendants(self.Y))) > 0:
if len(self.S.intersection(self.desY)) > 0:
raise ValueError("Not identifiable: Conditions S overlap with des(Y)")

if self.verbosity > 0:
Expand Down Expand Up @@ -345,16 +351,16 @@ def check_XYS_paths(self):
oldX = self.X.copy()
oldY = self.Y.copy()

anc_Y = self._get_ancestors(self.Y)
anc_S = self._get_ancestors(self.S)
# anc_Y = self._get_ancestors(self.Y)
# anc_S = self._get_ancestors(self.S)

# Remove first from X those nodes with no causal path to Y or S
X = set([x for x in self.X if x in anc_Y.union(anc_S)])
X = set([x for x in self.X if x in self.ancY.union(self.ancS)])

# Remove from Y those nodes with no causal path from X
des_X = self._get_descendants(X)
# des_X = self._get_descendants(X)

Y = set([y for y in self.Y if y in des_X])
Y = set([y for y in self.Y if y in self.desX])

# Also require that all x in X have proper path to Y or S,
# that is, the first link goes out of x
Expand Down Expand Up @@ -596,27 +602,33 @@ def _find_adj(self, node, patterns, exclude=None, return_link=False):
# Find adjacencies going forward/contemp
for k, lag_ik in zip(*np.where(graph[i,:,lag_i,:])):
# print((k, lag_ik), graph[i,k,lag_i,lag_ik])
matches = [self._match_link(patt, graph[i,k,lag_i,lag_ik]) for patt in patterns]
if np.any(matches):
match = (k, -lag_ik)
if match not in exclude:
if return_link:
adj.append((graph[i,k,lag_i,lag_ik], match))
else:
adj.append(match)
# matches = [self._match_link(patt, graph[i,k,lag_i,lag_ik]) for patt in patterns]
# if np.any(matches):
for patt in patterns:
if self._match_link(patt, graph[i,k,lag_i,lag_ik]):
match = (k, -lag_ik)
if match not in exclude:
if return_link:
adj.append((graph[i,k,lag_i,lag_ik], match))
else:
adj.append(match)
break


# Find adjacencies going backward/contemp
for k, lag_ki in zip(*np.where(graph[:,i,:,lag_i])):
# print((k, lag_ki), graph[k,i,lag_ki,lag_i])
matches = [self._match_link(self._reverse_link(patt), graph[k,i,lag_ki,lag_i]) for patt in patterns]
if np.any(matches):
match = (k, -lag_ki)
if match not in exclude:
if return_link:
adj.append((self._reverse_link(graph[k,i,lag_ki,lag_i]), match))
else:
adj.append(match)
# matches = [self._match_link(self._reverse_link(patt), graph[k,i,lag_ki,lag_i]) for patt in patterns]
# if np.any(matches):
for patt in patterns:
if self._match_link(self._reverse_link(patt), graph[k,i,lag_ki,lag_i]):
match = (k, -lag_ki)
if match not in exclude:
if return_link:
adj.append((self._reverse_link(graph[k,i,lag_ki,lag_i]), match))
else:
adj.append(match)
break

adj = list(set(adj))
return adj
Expand All @@ -634,7 +646,7 @@ def _is_match(self, nodei, nodej, pattern_ij):
return ((tauij >= 0 and self._match_link(pattern_ij, graph[i, j, tauij])) or
(tauij < 0 and self._match_link(self._reverse_link(pattern_ij), graph[j, i, abs(tauij)])))


# @profile
def _get_children(self, varlag):
"""Returns set of children (varlag --> ...) for (lagged) varlag."""
if self.possible:
Expand Down Expand Up @@ -745,6 +757,7 @@ def _get_descendants_stationary_graph(self, W, max_lag):

return descendants

# @profile
def _get_descendants(self, W):
"""Get descendants of nodes in W up to time t.
Expand Down Expand Up @@ -1079,6 +1092,7 @@ def _get_latent_projection_graph(self, stationary=False):

return aux_graph

# @profile
def _check_path(self,
# graph,
start, end,
Expand Down Expand Up @@ -1280,7 +1294,7 @@ def _check_path(self,
# print("Separated")
return False


# @profile
def get_optimal_set(self,
alternative_conditions=None,
minimize=False,
Expand Down Expand Up @@ -1314,9 +1328,17 @@ def get_optimal_set(self,
vancs = self.vancs.copy()
else:
S = alternative_conditions
vancs = self._get_ancestors(list(self.X.union(self.Y).union(S))) - self.forbidden_nodes
newancS = self._get_ancestors(S)
self.vancs = self.ancX.union(self.ancY).union(newancS) - self.forbidden_nodes

# vancs = self._get_ancestors(list(self.X.union(self.Y).union(S))) - self.forbidden_nodes

# descendants = self._get_descendants(self.Y.union(self.M))

# Sufficient condition for non-identifiability
if len(self.X.intersection(self.descendants)) > 0:
return False # raise ValueError("Not identifiable: Overlap between X and des(M)")

descendants = self._get_descendants(self.Y.union(self.M))

##
## Construct O-set
Expand Down Expand Up @@ -1409,8 +1431,9 @@ def get_optimal_set(self,
# if-statements of the construction algorithm, but for
# multivariate X there might be further cases... Hence,
# we here explicitely check validity
if self._check_validity(list(Oset_S)) is False:
return False
# if len(self.X) > 1:
# if self._check_validity(list(Oset_S)) is False:
# return False

if return_separate_sets:
return parents, colliders, collider_parents, S
Expand Down Expand Up @@ -1769,7 +1792,7 @@ def _get_causal_paths(self, source_nodes, target_nodes,

return all_causal_paths


# @profile
def fit_total_effect(self,
dataframe,
estimator,
Expand Down Expand Up @@ -1851,6 +1874,7 @@ def fit_total_effect(self,

return self

# @profile
def predict_total_effect(self,
intervention_data,
conditions_data=None,
Expand Down Expand Up @@ -1895,6 +1919,7 @@ def predict_total_effect(self,

return effect

# @profile
def fit_wright_effect(self,
dataframe,
mediation=None,
Expand Down Expand Up @@ -1953,7 +1978,7 @@ def fit_wright_effect(self,
mask_type=mask_type,
verbosity=self.verbosity)

mediators = self.get_mediators(start=self.X, end=self.Y)
mediators = self.M # self.get_mediators(start=self.X, end=self.Y)

if mediation == 'direct':
causal_paths = {}
Expand Down Expand Up @@ -2073,7 +2098,7 @@ def predict(self, X):
self.model.fit_results = fit_results
return self


# @profile
def predict_wright_effect(self,
intervention_data,
pred_params=None,
Expand Down Expand Up @@ -2222,7 +2247,7 @@ def _get_minmax_lag(links):
import sklearn
from sklearn.linear_model import LinearRegression

T = 10000
T = 1000
def lin_f(x): return x
auto_coeff = 0.3
coeff = 2.
Expand Down Expand Up @@ -2254,18 +2279,32 @@ def lin_f(x): return x
verbosity=1)

# Optimal adjustment set (is used by default)
print(causal_effects.get_optimal_set())
# print(causal_effects.get_optimal_set())

# # Fit causal effect model from observational data
# causal_effects.fit_total_effect(
# dataframe=dataframe,
# # mask_type='y',
# estimator=LinearRegression(),
# )

# # Predict effect of interventions do(X=0.), ..., do(X=1.) in one go
# dox_vals = np.linspace(0., 1., 5)
# intervention_data = dox_vals.reshape(len(dox_vals), len(X))
# pred_Y = causal_effects.predict_total_effect(
# intervention_data=intervention_data)
# print(pred_Y)

# Fit causal effect model from observational data
causal_effects.fit_total_effect(
causal_effects.fit_wright_effect(
dataframe=dataframe,
# mask_type='y',
estimator=LinearRegression(),
# estimator=LinearRegression(),
)

# Predict effect of interventions do(X=0.), ..., do(X=1.) in one go
dox_vals = np.linspace(0., 1., 5)
intervention_data = dox_vals.reshape(len(dox_vals), len(X))
pred_Y = causal_effects.predict_total_effect(
pred_Y = causal_effects.predict_wright_effect(
intervention_data=intervention_data)
print(pred_Y)
2 changes: 2 additions & 0 deletions tigramite/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(self,
self.tau_max = None
self.fit_results = None

# @profile
def get_general_fitted_model(self,
Y, X, Z=None,
conditions=None,
Expand Down Expand Up @@ -200,6 +201,7 @@ def get_general_fitted_model(self,
self.fit_results = fit_results
return fit_results

# @profile
def get_general_prediction(self,
intervention_data,
conditions_data=None,
Expand Down

0 comments on commit 13d0372

Please sign in to comment.