diff --git a/setup.py b/setup.py index 94e4c22c..af548baf 100644 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ def run(self): # Run the setup setup( name="tigramite", - version="5.0.1.4", + version="5.0.1.5", packages=["tigramite", "tigramite.independence_tests", "tigramite.toymodels"], license="GNU General Public License v3.0", description="Tigramite causal discovery for time series", diff --git a/tigramite/causal_effects.py b/tigramite/causal_effects.py index 5e878f53..131b71bf 100644 --- a/tigramite/causal_effects.py +++ b/tigramite/causal_effects.py @@ -1882,6 +1882,7 @@ def predict_total_effect(self, intervention_data, conditions_data=None, pred_params=None, + return_further_pred_results=False, ): """Predict effect of intervention with fitted model. @@ -1895,6 +1896,9 @@ def predict_total_effect(self, Numpy array of shape (time, len(S)) that contains the S=s values. pred_params : dict, optional Optional parameters passed on to sklearn prediction function. + return_further_pred_results : bool, optional (default: False) + In case the predictor class returns more than just the expected value, + the entire results can be returned. Returns ------- @@ -1918,7 +1922,8 @@ def predict_total_effect(self, effect = self.model.get_general_prediction( intervention_data=intervention_data, conditions_data=conditions_data, - pred_params=pred_params) + pred_params=pred_params, + return_further_pred_results=return_further_pred_results) return effect @@ -2074,6 +2079,8 @@ def fit_wright_effect(self, effect[(x, y)] += effect_here + # Make fitted coefficients available as attribute + self.coeffs = coeffs # Modify and overwrite variables in self.model self.model.Y = self.listY diff --git a/tigramite/independence_tests/gpdc.py b/tigramite/independence_tests/gpdc.py index 337f7fd6..9cba375f 100644 --- a/tigramite/independence_tests/gpdc.py +++ b/tigramite/independence_tests/gpdc.py @@ -5,7 +5,7 @@ # License: GNU General Public License v3.0 from __future__ import print_function -import json, warnings +import json, warnings, os, pathlib import numpy as np try: from importlib import metadata @@ -14,7 +14,7 @@ try: import dcor from sklearn import gaussian_process - with open('../versions.py', 'r') as vfile: + with open(pathlib.Path(os.path.dirname(__file__)) / '../../versions.py', 'r') as vfile: packages = json.loads(vfile.read())['all'] packages = dict(map(lambda s: s.split('>='), packages)) if metadata.version('dcor') < packages['dcor']: diff --git a/tigramite/independence_tests/gpdc_torch.py b/tigramite/independence_tests/gpdc_torch.py index 9e58264b..7bebc492 100644 --- a/tigramite/independence_tests/gpdc_torch.py +++ b/tigramite/independence_tests/gpdc_torch.py @@ -5,7 +5,7 @@ # License: GNU General Public License v3.0 from __future__ import print_function -import json, warnings +import json, warnings, os, pathlib import numpy as np import gc try: @@ -17,7 +17,7 @@ import torch import gpytorch from .LBFGS import FullBatchLBFGS - with open('../versions.py', 'r') as vfile: + with open(pathlib.Path(os.path.dirname(__file__)) / '../../versions.py', 'r') as vfile: packages = json.loads(vfile.read())['all'] packages = dict(map(lambda s: s.split('>='), packages)) if metadata.version('dcor') < packages['dcor']: diff --git a/tigramite/models.py b/tigramite/models.py index e5caabde..2f1febbc 100644 --- a/tigramite/models.py +++ b/tigramite/models.py @@ -6,7 +6,7 @@ from __future__ import print_function from copy import deepcopy -import json, warnings +import json, warnings, os, pathlib import numpy as np try: from importlib import metadata @@ -16,7 +16,7 @@ import sklearn import sklearn.linear_model import networkx - with open('../versions.py', 'r') as vfile: + with open(pathlib.Path(os.path.dirname(__file__)) / '../versions.py', 'r') as vfile: packages = json.loads(vfile.read())['all'] packages = dict(map(lambda s: s.split('>='), packages)) if metadata.version('scikit-learn') < packages['scikit-learn']: @@ -207,6 +207,7 @@ def get_general_prediction(self, intervention_data, conditions_data=None, pred_params=None, + return_further_pred_results=False, ): r"""Predict effect of intervention with fitted model. @@ -220,7 +221,9 @@ def get_general_prediction(self, Numpy array of shape (time, len(S)) that contains the S=s values. pred_params : dict, optional Optional parameters passed on to sklearn prediction function. - + return_further_pred_results : bool, optional (default: False) + In case the predictor class returns more than just the expected value, + the entire results can be returned. Returns ------- Results from prediction. @@ -244,6 +247,7 @@ def get_general_prediction(self, predicted_array = np.zeros((intervention_T, lenY)) pred_dict = {} for iy, y in enumerate(self.Y): + pred_dict[iy] = {} # Print message if self.verbosity > 1: print("\n## Predicting target %s" % str(y)) @@ -292,16 +296,28 @@ def get_general_prediction(self, # predicted_vals = a_transform.transform(X=target_array.T).T a_conditional_model = deepcopy(self.conditional_model) - a_conditional_model.fit(X=s_array, y=predicted_vals) + if type(predicted_vals) is tuple: + predicted_vals_here = predicted_vals[0] + else: + predicted_vals_here = predicted_vals + + a_conditional_model.fit(X=s_array, y=predicted_vals_here) self.fit_results[y]['conditional_model'] = a_conditional_model - predicted_array[index, iy] = a_conditional_model.predict( - X=conditions_array, **pred_params).mean() + predicted_vals = a_conditional_model.predict( + X=conditions_array, **pred_params) + # print(predicted_vals) + if type(predicted_vals) is tuple: + predicted_array[index, iy] = predicted_vals[0].mean() + pred_dict[iy][index] = predicted_vals else: predicted_array[index, iy] = predicted_vals.mean() - return predicted_array + if return_further_pred_results: + return predicted_array, pred_dict + else: + return predicted_array def get_fit(self, all_parents, diff --git a/tigramite/plotting.py b/tigramite/plotting.py index 64686f31..127121d6 100644 --- a/tigramite/plotting.py +++ b/tigramite/plotting.py @@ -5,7 +5,7 @@ # License: GNU General Public License v3.0 import numpy as np -import json, warnings +import json, warnings, os, pathlib try: from importlib import metadata except ImportError: @@ -13,7 +13,7 @@ try: import matplotlib import networkx as nx - with open('../versions.py', 'r') as vfile: + with open(pathlib.Path(os.path.dirname(__file__)) / '../versions.py', 'r') as vfile: packages = json.loads(vfile.read())['all'] packages = dict(map(lambda s: s.split('>='), packages)) if metadata.version('matplotlib') < packages['matplotlib']: