diff --git a/setup.py b/setup.py index bce0d17a..cfd147b8 100644 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ def run(self): # Run the setup setup( name="tigramite", - version="5.0.1.8", + version="5.0.1.9", packages=["tigramite", "tigramite.independence_tests", "tigramite.toymodels"], license="GNU General Public License v3.0", description="Tigramite causal discovery for time series", diff --git a/tigramite/plotting.py b/tigramite/plotting.py index 127121d6..13e82e59 100644 --- a/tigramite/plotting.py +++ b/tigramite/plotting.py @@ -38,7 +38,7 @@ from copy import deepcopy import matplotlib.path as mpath import matplotlib.patheffects as PathEffects - +from mpl_toolkits.axisartist.axislines import Axes # TODO: Add proper docstrings to internal functions... @@ -810,6 +810,314 @@ def savefig(self, name=None): pyplot.show() + +def plot_scatterplots(dataframe, name=None, setup_args={}, add_scatterplot_args={}): + """Wrapper helper function to plot scatter plots. + Sets up the matrix object and plots the scatter plots, see parameters in + setup_scatter_matrix and add_scatterplot. + + Parameters + ---------- + dataframe : data object + Tigramite dataframe object. It must have the attributes dataframe.values + yielding a numpy array of shape (observations T, variables N) and + optionally a mask of the same shape and a missing values flag. + name : str, optional (default: None) + File name. If None, figure is shown in window. + setup_args : dict + Arguments for setting up the scatter plot matrix, see doc of + setup_scatter_matrix. + add_scatterplot_args : dict + Arguments for adding a scatter plot matrix. + + Returns + ------- + matrix : object + Further scatter plot can be overlaid using the + matrix.add_scatterplot function. + """ + + N = dataframe.N + + matrix = setup_scatter_matrix(N=N, var_names=dataframe.var_names, **setup_args) + matrix.add_scatterplot(dataframe=dataframe, **add_scatterplot_args) + + if name is not None: + matrix.savefig(name=name) + + return matrix + + +class setup_scatter_matrix: + """Create matrix of scatter plot panels. + Class to setup figure object. The function add_scatterplot allows to plot + scatterplots of variables in the dataframe. Multiple scatter plots can be + overlaid for comparison. + + Parameters + ---------- + N : int + Number of variables + var_names : list, optional (default: None) + List of variable names. If None, range(N) is used. + figsize : tuple of floats, optional (default: None) + Figure size if new figure is created. If None, default pyplot figsize + is used. + label_space_left : float, optional (default: 0.1) + Fraction of horizontal figure space to allocate left of plot for labels. + label_space_top : float, optional (default: 0.05) + Fraction of vertical figure space to allocate top of plot for labels. + legend_width : float, optional (default: 0.15) + Fraction of horizontal figure space to allocate right of plot for + legend. + plot_gridlines : bool, optional (default: False) + Whether to show a grid. + label_fontsize : int, optional (default: 10) + Fontsize of variable labels. + """ + + def __init__( + self, + N, + var_names=None, + figsize=None, + label_space_left=0.1, + label_space_top=0.05, + legend_width=0.15, + legend_fontsize=10, + plot_gridlines=False, + label_fontsize=10, + ): + + self.labels = [] + + self.legend_width = legend_width + self.legend_fontsize = legend_fontsize + + self.label_space_left = label_space_left + self.label_space_top = label_space_top + self.label_fontsize = label_fontsize + + self.fig = pyplot.figure(figsize=figsize) + + self.axes_dict = {} + + if var_names is None: + var_names = range(N) + + plot_index = 1 + for i in range(N): + for j in range(N): + self.axes_dict[(i, j)] = self.fig.add_subplot(N, N, plot_index, axes_class=Axes) + # Plot process labels + if j == 0: + trans = transforms.blended_transform_factory( + self.fig.transFigure, self.axes_dict[(i, j)].transAxes + ) + self.axes_dict[(i, j)].text( + 0.01, + 0.5, + "%s" % str(var_names[i]), + fontsize=label_fontsize, + horizontalalignment="left", + verticalalignment="center", + transform=trans, + ) + if i == 0: + trans = transforms.blended_transform_factory( + self.axes_dict[(i, j)].transAxes, self.fig.transFigure + ) + self.axes_dict[(i, j)].text( + 0.5, + 0.99, + r"${\to}$ " + "%s" % str(var_names[j]), + fontsize=label_fontsize, + horizontalalignment="center", + verticalalignment="top", + transform=trans, + ) + + self.axes_dict[(i, j)].axis["right"].set_visible(False) + self.axes_dict[(i, j)].axis["top"].set_visible(False) + + if j != 0: + self.axes_dict[(i, j)].get_yaxis().set_ticklabels([]) + if i != dataframe.N - 1: + self.axes_dict[(i, j)].get_xaxis().set_ticklabels([]) + + if plot_gridlines: + self.axes_dict[(i, j)].grid( + True, + which="major", + color="black", + linestyle="dotted", + dashes=(1, 1), + linewidth=0.05, + zorder=-5, + ) + + plot_index += 1 + + def add_scatterplot( + self, + dataframe, + scatter_lags=None, + color="black", + label=None, + marker=".", + markersize=5, + alpha=1.0, + ): + """Add lag function plot from val_matrix array. + + Parameters + ---------- + dataframe : data object + Tigramite dataframe object. It must have the attributes dataframe.values + yielding a numpy array of shape (observations T, variables N) and + optionally a mask of the same shape and a missing values flag. + scatter_lags : array + Lags to use in scatter plots. Either None or of shape (N, N). Then the + entry scatter_lags[i, j] = tau will depict the scatter plot of + time series (i, -tau) vs (j, 0). If None, tau = 0 for i != j and for i = j + tau = 1. + color : str, optional (default: 'black') + Line color. + label : str + Test statistic label. + marker : matplotlib marker symbol, optional (default: '.') + Marker. + markersize : int, optional (default: 5) + Marker size. + alpha : float, optional (default: 1.) + Opacity. + """ + + if label is not None: + self.labels.append((label, color, marker, markersize, alpha)) + + for ij in list(self.axes_dict): + i = ij[0] + j = ij[1] + if scatter_lags is None: + if i == j: + lag = 1 + else: + lag = 0 + else: + lag = scatter_lags[i,j] + if lag == 0: + x = np.copy(dataframe.values[:, i]) + y = np.copy(dataframe.values[:, j]) + else: + x = np.copy(dataframe.values[:-lag, i]) + y = np.copy(dataframe.values[lag:, j]) + if dataframe.mask is not None: + x[dataframe.mask[:-lag, i]] = np.nan + y[dataframe.mask[lag:, j]] = np.nan + # print(i, j, lag, x.shape, y.shape) + self.axes_dict[(i, j)].scatter( + x, y, + color=color, + marker=marker, + s=markersize, + alpha=alpha, + clip_on=False, + label=r"$\tau{=}%d$" %lag, + ) + # self.axes_dict[(i, j)].text(0., 1., r"$\tau{=}%d$" %lag, + # fontsize=self.legend_fontsize, + # ha='left', va='top', + # transform=self.axes_dict[(i, j)].transAxes) + + + def savefig(self, name=None): + """Save matrix figure. + + Parameters + ---------- + name : str, optional (default: None) + File name. If None, figure is shown in window. + """ + + # Trick to plot legends + colors = [] + for item in self.labels: + colors.append(item[1]) + for ij in list(self.axes_dict): + i = ij[0] + j = ij[1] + + leg = self.axes_dict[(i, j)].legend( + # loc="upper left", + ncol=1, + # bbox_to_anchor=(1.05, 0.0, 0.1, 1.0), + # borderaxespad=0, + fontsize=self.legend_fontsize-2, + labelcolor=colors, + ).draw_frame(False) + + if len(self.labels) > 0: + axlegend = self.fig.add_subplot(111, frameon=False) + axlegend.spines["left"].set_color("none") + axlegend.spines["right"].set_color("none") + axlegend.spines["bottom"].set_color("none") + axlegend.spines["top"].set_color("none") + axlegend.set_xticks([]) + axlegend.set_yticks([]) + + # self.labels.append((label, color, marker, markersize, alpha)) + for item in self.labels: + label = item[0] + color = item[1] + marker = item[2] + markersize = item[3] + alpha = item[4] + + axlegend.plot( + [], + [], + linestyle="", + color=color, + marker=marker, + markersize=markersize, + label=label, + alpha=alpha, + ) + axlegend.legend( + loc="upper left", + ncol=1, + bbox_to_anchor=(1.05, 0.0, 0.1, 1.0), + borderaxespad=0, + fontsize=self.legend_fontsize, + ).draw_frame(False) + + self.fig.subplots_adjust( + bottom=0.05, + left=self.label_space_left, + right=1.0 - self.legend_width, + top=1.0 - self.label_space_top, + hspace=0.5, + wspace=0.35, + ) + + else: + self.fig.subplots_adjust( + left=self.label_space_left, + bottom=0.05, + right=0.95, + top=1.0 - self.label_space_top, + hspace=0.35, + wspace=0.35, + ) + + + if name is not None: + self.fig.savefig(name) + else: + pyplot.show() + + def _draw_network_with_curved_edges( fig, ax, @@ -3333,6 +3641,46 @@ def _links_to_tsg(link_coeffs, max_lag=None): if __name__ == "__main__": + import sys + matplotlib.rc('xtick', labelsize=6) + matplotlib.rc('ytick', labelsize=6) + + # Consider some toy data + import tigramite + import tigramite.toymodels.structural_causal_processes as toys + import tigramite.data_processing as pp + + T = 1000 + def lin_f(x): return x + auto_coeff = 0.3 + coeff = 1. + links = { + 0: [((0, -1), auto_coeff, lin_f)], + 1: [((1, -1), auto_coeff, lin_f), ((0, 0), coeff, lin_f)], + 2: [((2, -1), auto_coeff, lin_f), ((1, 0), coeff, lin_f)], + } + data, nonstat = toys.structural_causal_process(links, T=T, + noises=None, seed=7) + + dataframe = pp.DataFrame(data, var_names=range(len(links))) + plot_scatterplots(dataframe, name='scattertest.pdf') + + # matrix = setup_scatter_matrix(N=dataframe.N, + # var_names=dataframe.var_names) + # scatter_lags = np.ones((3, 3)).astype('int') + # matrix.add_scatterplot(dataframe=dataframe, scatter_lags=scatter_lags, + # label='ones', alpha=0.4) + # scatter_lags = 2*np.ones((3, 3)).astype('int') + # matrix.add_scatterplot(dataframe=dataframe, scatter_lags=scatter_lags, + # label='twos', color='red', alpha=0.4) + + # matrix.savefig(name='scattertest.pdf') + + + # pyplot.show() + sys.exit(0) + + val_matrix = np.zeros((4, 4, 3)) # Complete test case diff --git a/tutorials/tigramite_tutorial_general_causal_effect_analysis.ipynb b/tutorials/tigramite_tutorial_general_causal_effect_analysis.ipynb index da97bbf1..66102afa 100644 --- a/tutorials/tigramite_tutorial_general_causal_effect_analysis.ipynb +++ b/tutorials/tigramite_tutorial_general_causal_effect_analysis.ipynb @@ -2055,6 +2055,26 @@ "plt.show()\n" ] }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2 0 1]\n", + "[1 2 0]\n" + ] + } + ], + "source": [ + "a=np.array([0.4, 0.8, 0.1])\n", + "print(np.argsort(a))\n", + "print(np.argsort(np.argsort(a)))" + ] + }, { "cell_type": "code", "execution_count": null,