Skip to content

Commit

Permalink
Merge pull request #182 from jakobrunge/developer
Browse files Browse the repository at this point in the history
fix bug regarding causal_effects.py and masking
  • Loading branch information
jakobrunge authored Mar 10, 2022
2 parents dab772d + 1c79b24 commit 826b4ec
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def run(self):
# Run the setup
setup(
name="tigramite",
version="5.0.0.3",
version="5.0.0.4",
packages=["tigramite", "tigramite.independence_tests", "tigramite.toymodels"],
license="GNU General Public License v3.0",
description="Tigramite causal discovery for time series",
Expand Down
3 changes: 3 additions & 0 deletions tigramite/causal_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1843,6 +1843,7 @@ def fit_total_effect(self,
conditions=self.listS,
tau_max=self.tau_max,
cut_off='max_lag_or_tau_max',
remove_missing_upto_maxlag=False,
return_data=False)

return self
Expand Down Expand Up @@ -1998,6 +1999,7 @@ def fit_wright_effect(self,
Y=[medy], X=[par], Z=oset,
tau_max=self.tau_max,
cut_off='max_lag_or_tau_max',
remove_missing_upto_maxlag=False,
return_data=False)
coeffs[medy][par] = fit_res[medy]['model'].coef_[0]
# print(mediators, par, medy, coeffs[medy][par])
Expand All @@ -2018,6 +2020,7 @@ def fit_wright_effect(self,
conditions=None,
tau_max=self.tau_max,
cut_off='max_lag_or_tau_max',
remove_missing_upto_maxlag=False,
return_data=False)

for ipar, par in enumerate(all_parents):
Expand Down
16 changes: 13 additions & 3 deletions tigramite/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def construct_array(self, X, Y, Z, tau_max,
return_cleaned_xyz=False,
do_checks=True,
cut_off='2xtau_max',
remove_missing_upto_maxlag=True,
verbosity=0):
"""Constructs array from variables X, Y, Z from data.
Expand Down Expand Up @@ -155,6 +156,9 @@ def construct_array(self, X, Y, Z, tau_max,
which uses the maximum of tau_max and the conditions, which is
useful to compare multiple models on the same sample. Last,
'max_lag' uses as much samples as possible.
remove_missing_upto_maxlag : bool, optional (default: True)
Whether to remove not only missing samples, but also all neighboring
samples up to max_lag (as given by cut_off).
verbosity : int, optional (default: 0)
Level of verbosity.
Expand Down Expand Up @@ -228,11 +232,17 @@ def construct_array(self, X, Y, Z, tau_max,
# slices that occur up to max_lag after
if self.missing_flag is not None:
missing_anywhere = np.any(np.isnan(self.values), axis=1)
for tau in range(max_lag+1):
if remove_missing_upto_maxlag:
for tau in range(max_lag+1):
if self.bootstrap is None:
use_indices[missing_anywhere[tau:T-max_lag+tau]] = 0
else:
use_indices[missing_anywhere[self.bootstrap - max_lag + tau]] = 0
else:
if self.bootstrap is None:
use_indices[missing_anywhere[tau:T-max_lag+tau]] = 0
use_indices[missing_anywhere] = 0
else:
use_indices[missing_anywhere[self.bootstrap - max_lag + tau]] = 0
use_indices[missing_anywhere[self.bootstrap]] = 0

# Use the mask override if needed
_use_mask = mask
Expand Down
5 changes: 5 additions & 0 deletions tigramite/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def get_general_fitted_model(self,
conditions=None,
tau_max=None,
cut_off='max_lag_or_tau_max',
remove_missing_upto_maxlag=True,
return_data=False):
"""Fit time series model.
Expand All @@ -108,6 +109,9 @@ def get_general_fitted_model(self,
sample. Other options are '2xtau_max', which guarantees that MCI
tests are all conducted on the same samples. Last, 'max_lag' uses
as much samples as possible.
remove_missing_upto_maxlag : bool, optional (default: True)
Whether to remove not only missing samples, but also all neighboring
samples up to max_lag (as given by cut_off).
return_data : bool, optional (default: False)
Whether to save the data array.
Expand Down Expand Up @@ -160,6 +164,7 @@ def get_general_fitted_model(self,
tau_max=self.tau_max,
mask_type=self.mask_type,
cut_off=self.cut_off,
remove_missing_upto_maxlag=remove_missing_upto_maxlag,
verbosity=self.verbosity)


Expand Down

0 comments on commit 826b4ec

Please sign in to comment.