Skip to content

Commit

Permalink
Merge pull request #191 from jakobrunge/developer
Browse files Browse the repository at this point in the history
fixed bug regarding missing values and imported package version checks
  • Loading branch information
jakobrunge authored Mar 22, 2022
2 parents 13d0372 + be5e05b commit c25889b
Show file tree
Hide file tree
Showing 9 changed files with 106 additions and 45 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ Tigramite is a causal time series analysis python package. It allows to efficien
- scikit-learn>=0.21 # Gaussian Process (GP) Regression
- matplotlib>=3.4.0 # Plotting
- networkx>=2.4 # Plotting
- torch>=1.7 # GPDC torch version
- gpytorch>=1.4 # GPDC torch version
- pytorch>=1.11.0 # GPDC pytorch version
- gpytorch>=1.4 # GPDC gpytorch version
- dcor>=0.5.3 # GPDC distance correlation version

## Installation
Expand Down Expand Up @@ -100,4 +100,4 @@ Copyright (C) 2014-2022 Jakob Runge

See license.txt for full text.

TIGRAMITE is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 3 of the License, or (at your option) any later version. TIGRAMITE is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
TIGRAMITE is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 3 of the License, or (at your option) any later version. TIGRAMITE is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
12 changes: 8 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
from setuptools import setup, Extension
from setuptools.command.build_ext import build_ext
import json

# Handle building against numpy headers before installing numpy

Expand Down Expand Up @@ -33,22 +34,25 @@ def run(self):

# Define the minimal classes needed to install and run tigramite
# INSTALL_REQUIRES = ["numpy==1.21.4", "scipy==1.7.2", "numba==0.53.1", "six"]
INSTALL_REQUIRES = ["numpy", "scipy", "numba", "six"]
INSTALL_REQUIRES = ["numpy", "scipy", "numba", "six"]
# Define all the possible extras needed
EXTRAS_REQUIRE = {
"all": [
"scikit-learn>=0.21", # Gaussian Process (GP) Regression
"matplotlib>=3.4.0", # plotting
"networkx>=2.4", # plotting
"torch>=1.7", # GPDC torch version
"pytorch>=1.11.0", # GPDC pytorch version
"gpytorch>=1.4", # GPDC gpytorch version
"dcor>=0.5.3", # GPDC distance correlation version
]
}

with open('versions.py', 'w') as vfile:
vfile.write(json.dumps(EXTRAS_REQUIRE))

# Define the packages needed for testing
TESTS_REQUIRE = ["nose", "pytest", "networkx>=2.4", "scikit-learn>=0.21",
"torch>=1.7", "gpytorch>=1.4", "dcor>=0.5.3"]
"pytorch>=1.11.0", "gpytorch>=1.4", "dcor>=0.5.3"]
EXTRAS_REQUIRE["test"] = TESTS_REQUIRE
# Define the extras needed for development
EXTRAS_REQUIRE["dev"] = EXTRAS_REQUIRE["all"]
Expand All @@ -59,7 +63,7 @@ def run(self):
# Run the setup
setup(
name="tigramite",
version="5.0.0.8",
version="5.0.1.0",
packages=["tigramite", "tigramite.independence_tests", "tigramite.toymodels"],
license="GNU General Public License v3.0",
description="Tigramite causal discovery for time series",
Expand Down
11 changes: 10 additions & 1 deletion tests/test_construct_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def test_construct_array(cstrct_array_params):
missing_flag = data[earliest_time, a_nd]
# Record that the row with this value and all rows up to max_lag after
# this value have been cut off as well
n_rows_masked += max_lag + 1
# n_rows_masked += 1

# Construct the array
data_f = pp.DataFrame(data, data_mask, missing_flag)
Expand All @@ -122,6 +122,15 @@ def test_construct_array(cstrct_array_params):
# masked variable, which removes the first n time slices in the returned
# array
expect_array = expect_array[:, n_rows_masked:]
if missing_vals:
missing_anywhere_base = np.array(np.where(np.any(expect_array==missing_flag, axis=0))[0])
missing_anywhere = list(missing_anywhere_base)
for tau in range(1, max_lag+1):
missing_anywhere += list(np.array(missing_anywhere_base) + tau)
expect_array = np.delete(expect_array, missing_anywhere, axis=1)

# Test the results
# print(array)
# print(expect_array)
np.testing.assert_almost_equal(array, expect_array)
np.testing.assert_almost_equal(xyz, expect_xyz)
18 changes: 7 additions & 11 deletions tigramite/causal_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -1339,7 +1339,6 @@ def get_optimal_set(self,
if len(self.X.intersection(self.descendants)) > 0:
return False # raise ValueError("Not identifiable: Overlap between X and des(M)")


##
## Construct O-set
##
Expand Down Expand Up @@ -1427,14 +1426,6 @@ def get_optimal_set(self,

Oset_S = Oset.union(S)

# For singleton X the validity is already checked in the
# if-statements of the construction algorithm, but for
# multivariate X there might be further cases... Hence,
# we here explicitly check validity
# if len(self.X) > 1:
# if self._check_validity(list(Oset_S)) is False:
# return False

if return_separate_sets:
return parents, colliders, collider_parents, S
else:
Expand Down Expand Up @@ -2258,7 +2249,11 @@ def lin_f(x): return x
}
data, nonstat = toys.structural_causal_process(links, T=T,
noises=None, seed=7)
dataframe = pp.DataFrame(data)

# Create some missing values
data[:10,:] = 999.
dataframe = pp.DataFrame(data, missing_flag=999.)


# Construct expert knowledge graph from links here
links = {0: [(0, -1)],
Expand All @@ -2276,8 +2271,9 @@ def lin_f(x): return x
causal_effects = CausalEffects(graph, graph_type='stationary_dag',
X=X, Y=Y, S=None,
hidden_variables=None,
verbosity=1)
verbosity=5)

print(data)
# Optimal adjustment set (is used by default)
# print(causal_effects.get_optimal_set())

Expand Down
10 changes: 5 additions & 5 deletions tigramite/data_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class DataFrame():
def __init__(self, data, mask=None, missing_flag=None, var_names=None,
datatime=None):

self.values = data
self.values = data.copy()
self.mask = mask
self.missing_flag = missing_flag
if self.missing_flag is not None:
Expand Down Expand Up @@ -228,16 +228,16 @@ def construct_array(self, X, Y, Z, tau_max,
# Choose which indices to use
use_indices = np.ones(time_length, dtype='int')

# Remove all values that have missing value flag, as well as the time
# Remove all values that have the missing value flag and, optionally, also the time
# slices that occur up to max_lag after
if self.missing_flag is not None:
missing_anywhere = np.any(np.isnan(self.values), axis=1)
missing_anywhere = np.array(np.where(np.any(np.isnan(array), axis=0))[0])
if remove_missing_upto_maxlag:
for tau in range(max_lag+1):
if self.bootstrap is None:
use_indices[missing_anywhere[tau:T-max_lag+tau]] = 0
use_indices[missing_anywhere + tau] = 0
else:
use_indices[missing_anywhere[self.bootstrap - max_lag + tau]] = 0
use_indices[missing_anywhere[self.bootstrap] + tau] = 0
else:
if self.bootstrap is None:
use_indices[missing_anywhere] = 0
Expand Down
18 changes: 15 additions & 3 deletions tigramite/independence_tests/gpdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,27 @@
# License: GNU General Public License v3.0

from __future__ import print_function
import json, warnings
import numpy as np
import warnings
from .independence_tests_base import CondIndTest

try:
from importlib import metadata
except ImportError:
import importlib_metadata as metadata # python<=3.7
try:
import dcor
from sklearn import gaussian_process
with open('../versions.py', 'r') as vfile:
packages = json.loads(vfile.read())['all']
packages = dict(map(lambda s: s.split('>='), packages))
if metadata.version('dcor') < packages['dcor']:
raise Exception('Version mismatch. Installed version of dcor', metadata.version('dcor'),
'Please install dcor>=', packages['dcor'])
if metadata.version('scikit-learn') < packages['scikit-learn']:
raise Exception('Version mismatch. Installed version of scikit-learn', metadata.version('scikit-learn'),
'Please install scikit-learn>=', packages['scikit-learn'])
except Exception as e:
warnings.warn(str(e))
from .independence_tests_base import CondIndTest

class GaussProcReg():
r"""Gaussian processes abstract base class.
Expand Down
20 changes: 18 additions & 2 deletions tigramite/independence_tests/gpdc_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,33 @@
# License: GNU General Public License v3.0

from __future__ import print_function
import warnings
import json, warnings
import numpy as np
import gc
from .independence_tests_base import CondIndTest
try:
from importlib import metadata
except ImportError:
import importlib_metadata as metadata # python<=3.7
try:
import dcor
import torch
import gpytorch
from .LBFGS import FullBatchLBFGS
with open('../versions.py', 'r') as vfile:
packages = json.loads(vfile.read())['all']
packages = dict(map(lambda s: s.split('>='), packages))
if metadata.version('dcor') < packages['dcor']:
raise Exception('Version mismatch. Installed version of dcor', metadata.version('dcor'),
'Please install dcor>=', packages['dcor'])
if metadata.version('torch') < packages['pytorch']:
raise Exception('Version mismatch. Installed version of pytorch', metadata.version('torch'),
'Please install pytorch>=', packages['pytorch'])
if metadata.version('gpytorch') < packages['gpytorch']:
raise Exception('Version mismatch. Installed version of gpytorch', metadata.version('gpytorch'),
'Please install gpytorch>=', packages['gpytorch'])
except Exception as e:
warnings.warn(str(e))
from .independence_tests_base import CondIndTest

class GaussProcRegTorch():
r"""Gaussian processes abstract base class.
Expand Down
34 changes: 20 additions & 14 deletions tigramite/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,29 @@

from __future__ import print_function
from copy import deepcopy

import json, warnings
import numpy as np

from tigramite.data_processing import DataFrame
from tigramite.pcmci import PCMCI

# try:
import sklearn
import sklearn.linear_model
# except:
# print("Could not import sklearn...")

try:
from importlib import metadata
except ImportError:
import importlib_metadata as metadata # python<=3.7
try:
import sklearn
import sklearn.linear_model
import networkx
except:
print("Could not import networkx, LinearMediation plots not possible...")

with open('../versions.py', 'r') as vfile:
packages = json.loads(vfile.read())['all']
packages = dict(map(lambda s: s.split('>='), packages))
if metadata.version('scikit-learn') < packages['scikit-learn']:
raise Exception('Version mismatch. Installed version of scikit-learn', metadata.version('scikit-learn'),
'Please install scikit-learn>=', packages['scikit-learn'])
if metadata.version('networkx') < packages['networkx']:
raise Exception('Version mismatch. Installed version of networkx', metadata.version('networkx'),
'Please install networkx>=', packages['networkx'])
except Exception as e:
warnings.warn(str(e))
from tigramite.data_processing import DataFrame
from tigramite.pcmci import PCMCI

class Models():
"""Base class for time series models.
Expand Down
22 changes: 20 additions & 2 deletions tigramite/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,26 @@
# License: GNU General Public License v3.0

import numpy as np
import matplotlib
import json, warnings
try:
from importlib import metadata
except ImportError:
import importlib_metadata as metadata # python<=3.7
try:
import matplotlib
import networkx as nx
with open('../versions.py', 'r') as vfile:
packages = json.loads(vfile.read())['all']
packages = dict(map(lambda s: s.split('>='), packages))
if metadata.version('matplotlib') < packages['matplotlib']:
raise Exception('Version mismatch. Installed version of matplotlib', metadata.version('matplotlib'),
'Please install matplotlib>=', packages['matplotlib'])
if metadata.version('networkx') < packages['networkx']:
raise Exception('Version mismatch. Installed version of networkx', metadata.version('networkx'),
'Please install networkx>=', packages['networkx'])
except Exception as e:
warnings.warn(str(e))

from matplotlib.colors import ListedColormap
import matplotlib.transforms as transforms
from matplotlib import pyplot, ticker
Expand All @@ -15,7 +34,6 @@

import sys
from operator import sub
import networkx as nx
import tigramite.data_processing as pp
from copy import deepcopy
import matplotlib.path as mpath
Expand Down

0 comments on commit c25889b

Please sign in to comment.