From 62f39bfdd032600e0b333cd583f6b70394d33b35 Mon Sep 17 00:00:00 2001
From: jakobrunge <jakobrunge@gmail.com>
Date: Tue, 5 Apr 2022 23:29:01 +0200
Subject: [PATCH] added plot_scatterplots

---
 setup.py                                      |   2 +-
 tigramite/plotting.py                         | 350 +++++++++++++++++-
 ...orial_general_causal_effect_analysis.ipynb |  20 +
 3 files changed, 370 insertions(+), 2 deletions(-)

diff --git a/setup.py b/setup.py
index bce0d17a..cfd147b8 100644
--- a/setup.py
+++ b/setup.py
@@ -63,7 +63,7 @@ def run(self):
 # Run the setup
 setup(
     name="tigramite",
-    version="5.0.1.8",
+    version="5.0.1.9",
     packages=["tigramite", "tigramite.independence_tests", "tigramite.toymodels"],
     license="GNU General Public License v3.0",
     description="Tigramite causal discovery for time series",
diff --git a/tigramite/plotting.py b/tigramite/plotting.py
index 127121d6..13e82e59 100644
--- a/tigramite/plotting.py
+++ b/tigramite/plotting.py
@@ -38,7 +38,7 @@
 from copy import deepcopy
 import matplotlib.path as mpath
 import matplotlib.patheffects as PathEffects
-
+from mpl_toolkits.axisartist.axislines import Axes
 # TODO: Add proper docstrings to internal functions...
 
 
@@ -810,6 +810,314 @@ def savefig(self, name=None):
             pyplot.show()
 
 
+
+def plot_scatterplots(dataframe, name=None, setup_args={}, add_scatterplot_args={}):
+    """Wrapper helper function to plot scatter plots.
+    Sets up the matrix object and plots the scatter plots, see parameters in
+    setup_scatter_matrix and add_scatterplot.
+
+    Parameters
+    ----------
+    dataframe : data object
+        Tigramite dataframe object. It must have the attributes dataframe.values
+        yielding a numpy array of shape (observations T, variables N) and
+        optionally a mask of the same shape and a missing values flag.
+    name : str, optional (default: None)
+        File name. If None, figure is shown in window.
+    setup_args : dict
+        Arguments for setting up the scatter plot matrix, see doc of
+        setup_scatter_matrix.
+    add_scatterplot_args : dict
+        Arguments for adding a scatter plot matrix.
+
+    Returns
+    -------
+    matrix : object
+        Further scatter plot can be overlaid using the
+        matrix.add_scatterplot function.
+    """
+
+    N = dataframe.N
+
+    matrix = setup_scatter_matrix(N=N, var_names=dataframe.var_names, **setup_args)
+    matrix.add_scatterplot(dataframe=dataframe, **add_scatterplot_args)
+
+    if name is not None:
+        matrix.savefig(name=name)
+
+    return matrix
+
+
+class setup_scatter_matrix:
+    """Create matrix of scatter plot panels.
+    Class to setup figure object. The function add_scatterplot allows to plot
+    scatterplots of variables in the dataframe. Multiple scatter plots can be
+    overlaid for comparison.
+
+    Parameters
+    ----------
+    N : int
+        Number of variables
+    var_names : list, optional (default: None)
+        List of variable names. If None, range(N) is used.
+    figsize : tuple of floats, optional (default: None)
+        Figure size if new figure is created. If None, default pyplot figsize
+        is used.
+    label_space_left : float, optional (default: 0.1)
+        Fraction of horizontal figure space to allocate left of plot for labels.
+    label_space_top : float, optional (default: 0.05)
+        Fraction of vertical figure space to allocate top of plot for labels.
+    legend_width : float, optional (default: 0.15)
+        Fraction of horizontal figure space to allocate right of plot for
+        legend.
+    plot_gridlines : bool, optional (default: False)
+        Whether to show a grid.
+    label_fontsize : int, optional (default: 10)
+        Fontsize of variable labels.
+    """
+
+    def __init__(
+        self,
+        N,
+        var_names=None,
+        figsize=None,
+        label_space_left=0.1,
+        label_space_top=0.05,
+        legend_width=0.15,
+        legend_fontsize=10,
+        plot_gridlines=False,
+        label_fontsize=10,
+    ):
+
+        self.labels = []
+    
+        self.legend_width = legend_width
+        self.legend_fontsize = legend_fontsize
+
+        self.label_space_left = label_space_left
+        self.label_space_top = label_space_top
+        self.label_fontsize = label_fontsize
+
+        self.fig = pyplot.figure(figsize=figsize)
+
+        self.axes_dict = {}
+
+        if var_names is None:
+            var_names = range(N)
+
+        plot_index = 1
+        for i in range(N):
+            for j in range(N):
+                self.axes_dict[(i, j)] = self.fig.add_subplot(N, N, plot_index, axes_class=Axes)
+                # Plot process labels
+                if j == 0:
+                    trans = transforms.blended_transform_factory(
+                        self.fig.transFigure, self.axes_dict[(i, j)].transAxes
+                    )
+                    self.axes_dict[(i, j)].text(
+                        0.01,
+                        0.5,
+                        "%s" % str(var_names[i]),
+                        fontsize=label_fontsize,
+                        horizontalalignment="left",
+                        verticalalignment="center",
+                        transform=trans,
+                    )
+                if i == 0:
+                    trans = transforms.blended_transform_factory(
+                        self.axes_dict[(i, j)].transAxes, self.fig.transFigure
+                    )
+                    self.axes_dict[(i, j)].text(
+                        0.5,
+                        0.99,
+                        r"${\to}$ " + "%s" % str(var_names[j]),
+                        fontsize=label_fontsize,
+                        horizontalalignment="center",
+                        verticalalignment="top",
+                        transform=trans,
+                    )
+
+                self.axes_dict[(i, j)].axis["right"].set_visible(False)
+                self.axes_dict[(i, j)].axis["top"].set_visible(False)
+
+                if j != 0:
+                    self.axes_dict[(i, j)].get_yaxis().set_ticklabels([])
+                if i != dataframe.N - 1:
+                    self.axes_dict[(i, j)].get_xaxis().set_ticklabels([])
+
+                if plot_gridlines:
+                    self.axes_dict[(i, j)].grid(
+                        True,
+                        which="major",
+                        color="black",
+                        linestyle="dotted",
+                        dashes=(1, 1),
+                        linewidth=0.05,
+                        zorder=-5,
+                    )
+
+                plot_index += 1
+
+    def add_scatterplot(
+        self,
+        dataframe,
+        scatter_lags=None,
+        color="black",
+        label=None,
+        marker=".",
+        markersize=5,
+        alpha=1.0,
+    ):
+        """Add lag function plot from val_matrix array.
+
+        Parameters
+        ----------
+        dataframe : data object
+            Tigramite dataframe object. It must have the attributes dataframe.values
+            yielding a numpy array of shape (observations T, variables N) and
+            optionally a mask of the same shape and a missing values flag.
+        scatter_lags : array
+            Lags to use in scatter plots. Either None or of shape (N, N). Then the
+            entry scatter_lags[i, j] = tau will depict the scatter plot of 
+            time series (i, -tau) vs (j, 0). If None, tau = 0 for i != j and for i = j
+            tau = 1. 
+        color : str, optional (default: 'black')
+            Line color.
+        label : str
+            Test statistic label.
+        marker : matplotlib marker symbol, optional (default: '.')
+            Marker.
+        markersize : int, optional (default: 5)
+            Marker size.
+        alpha : float, optional (default: 1.)
+            Opacity.
+        """
+
+        if label is not None:
+            self.labels.append((label, color, marker, markersize, alpha))
+
+        for ij in list(self.axes_dict):                
+            i = ij[0]
+            j = ij[1]
+            if scatter_lags is None:
+                if i == j:
+                    lag = 1
+                else:
+                    lag = 0
+            else:
+                lag = scatter_lags[i,j]
+            if lag == 0:
+                x = np.copy(dataframe.values[:, i])
+                y = np.copy(dataframe.values[:, j])
+            else:
+                x = np.copy(dataframe.values[:-lag, i])
+                y = np.copy(dataframe.values[lag:, j])
+            if dataframe.mask is not None:
+                x[dataframe.mask[:-lag, i]] = np.nan
+                y[dataframe.mask[lag:, j]] = np.nan
+            # print(i, j, lag, x.shape, y.shape)
+            self.axes_dict[(i, j)].scatter(
+                x, y,
+                color=color,
+                marker=marker,
+                s=markersize,
+                alpha=alpha,
+                clip_on=False,
+                label=r"$\tau{=}%d$" %lag,
+            )
+            # self.axes_dict[(i, j)].text(0., 1., r"$\tau{=}%d$" %lag, 
+            #     fontsize=self.legend_fontsize,
+            #     ha='left', va='top',
+            #     transform=self.axes_dict[(i, j)].transAxes)
+
+
+    def savefig(self, name=None):
+        """Save matrix figure.
+
+        Parameters
+        ----------
+        name : str, optional (default: None)
+            File name. If None, figure is shown in window.
+        """
+
+        # Trick to plot legends
+        colors = []
+        for item in self.labels:
+            colors.append(item[1])
+        for ij in list(self.axes_dict):                
+            i = ij[0]
+            j = ij[1]
+
+            leg = self.axes_dict[(i, j)].legend(
+                # loc="upper left",
+                ncol=1,
+                # bbox_to_anchor=(1.05, 0.0, 0.1, 1.0),
+                # borderaxespad=0,
+                fontsize=self.legend_fontsize-2,
+                labelcolor=colors,
+                ).draw_frame(False)
+        
+        if len(self.labels) > 0:
+            axlegend = self.fig.add_subplot(111, frameon=False)
+            axlegend.spines["left"].set_color("none")
+            axlegend.spines["right"].set_color("none")
+            axlegend.spines["bottom"].set_color("none")
+            axlegend.spines["top"].set_color("none")
+            axlegend.set_xticks([])
+            axlegend.set_yticks([])
+
+            # self.labels.append((label, color, marker, markersize, alpha))
+            for item in self.labels:
+                label = item[0]
+                color = item[1]
+                marker = item[2]
+                markersize = item[3]
+                alpha = item[4]
+
+                axlegend.plot(
+                    [],
+                    [],
+                    linestyle="",
+                    color=color,
+                    marker=marker,
+                    markersize=markersize,
+                    label=label,
+                    alpha=alpha,
+                )
+            axlegend.legend(
+                loc="upper left",
+                ncol=1,
+                bbox_to_anchor=(1.05, 0.0, 0.1, 1.0),
+                borderaxespad=0,
+                fontsize=self.legend_fontsize,
+            ).draw_frame(False)
+
+            self.fig.subplots_adjust(
+                bottom=0.05,
+                left=self.label_space_left,
+                right=1.0 - self.legend_width,
+                top=1.0 - self.label_space_top,
+                hspace=0.5,
+                wspace=0.35,
+            )
+      
+        else:
+            self.fig.subplots_adjust(
+                left=self.label_space_left,
+                bottom=0.05,
+                right=0.95,
+                top=1.0 - self.label_space_top,
+                hspace=0.35,
+                wspace=0.35,
+            )
+       
+
+        if name is not None:
+            self.fig.savefig(name)
+        else:
+            pyplot.show()
+
+
 def _draw_network_with_curved_edges(
     fig,
     ax,
@@ -3333,6 +3641,46 @@ def _links_to_tsg(link_coeffs, max_lag=None):
 
 if __name__ == "__main__":
 
+    import sys
+    matplotlib.rc('xtick', labelsize=6) 
+    matplotlib.rc('ytick', labelsize=6) 
+
+    # Consider some toy data
+    import tigramite
+    import tigramite.toymodels.structural_causal_processes as toys
+    import tigramite.data_processing as pp
+
+    T = 1000
+    def lin_f(x): return x
+    auto_coeff = 0.3
+    coeff = 1.
+    links = {
+            0: [((0, -1), auto_coeff, lin_f)], 
+            1: [((1, -1), auto_coeff, lin_f), ((0, 0), coeff, lin_f)], 
+            2: [((2, -1), auto_coeff, lin_f), ((1, 0), coeff, lin_f)],
+            }
+    data, nonstat = toys.structural_causal_process(links, T=T, 
+                                noises=None, seed=7)
+
+    dataframe = pp.DataFrame(data, var_names=range(len(links)))
+    plot_scatterplots(dataframe, name='scattertest.pdf')
+    
+    # matrix = setup_scatter_matrix(N=dataframe.N, 
+    #     var_names=dataframe.var_names)
+    # scatter_lags = np.ones((3, 3)).astype('int')
+    # matrix.add_scatterplot(dataframe=dataframe, scatter_lags=scatter_lags,
+    #             label='ones', alpha=0.4)
+    # scatter_lags = 2*np.ones((3, 3)).astype('int')
+    # matrix.add_scatterplot(dataframe=dataframe, scatter_lags=scatter_lags, 
+    #     label='twos', color='red', alpha=0.4)
+
+    # matrix.savefig(name='scattertest.pdf')
+    
+
+    # pyplot.show()
+    sys.exit(0)
+
+
     val_matrix = np.zeros((4, 4, 3))
 
     # Complete test case
diff --git a/tutorials/tigramite_tutorial_general_causal_effect_analysis.ipynb b/tutorials/tigramite_tutorial_general_causal_effect_analysis.ipynb
index da97bbf1..66102afa 100644
--- a/tutorials/tigramite_tutorial_general_causal_effect_analysis.ipynb
+++ b/tutorials/tigramite_tutorial_general_causal_effect_analysis.ipynb
@@ -2055,6 +2055,26 @@
     "plt.show()\n"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 55,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[2 0 1]\n",
+      "[1 2 0]\n"
+     ]
+    }
+   ],
+   "source": [
+    "a=np.array([0.4, 0.8, 0.1])\n",
+    "print(np.argsort(a))\n",
+    "print(np.argsort(np.argsort(a)))"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,