Skip to content

Commit

Permalink
Merge pull request #203 from jakobrunge/developer
Browse files Browse the repository at this point in the history
added plot_scatterplots
  • Loading branch information
jakobrunge authored Apr 5, 2022
2 parents 27d6a9e + 62f39bf commit 1ff2446
Show file tree
Hide file tree
Showing 3 changed files with 370 additions and 2 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def run(self):
# Run the setup
setup(
name="tigramite",
version="5.0.1.8",
version="5.0.1.9",
packages=["tigramite", "tigramite.independence_tests", "tigramite.toymodels"],
license="GNU General Public License v3.0",
description="Tigramite causal discovery for time series",
Expand Down
350 changes: 349 additions & 1 deletion tigramite/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from copy import deepcopy
import matplotlib.path as mpath
import matplotlib.patheffects as PathEffects

from mpl_toolkits.axisartist.axislines import Axes
# TODO: Add proper docstrings to internal functions...


Expand Down Expand Up @@ -810,6 +810,314 @@ def savefig(self, name=None):
pyplot.show()



def plot_scatterplots(dataframe, name=None, setup_args={}, add_scatterplot_args={}):
"""Wrapper helper function to plot scatter plots.
Sets up the matrix object and plots the scatter plots, see parameters in
setup_scatter_matrix and add_scatterplot.
Parameters
----------
dataframe : data object
Tigramite dataframe object. It must have the attributes dataframe.values
yielding a numpy array of shape (observations T, variables N) and
optionally a mask of the same shape and a missing values flag.
name : str, optional (default: None)
File name. If None, figure is shown in window.
setup_args : dict
Arguments for setting up the scatter plot matrix, see doc of
setup_scatter_matrix.
add_scatterplot_args : dict
Arguments for adding a scatter plot matrix.
Returns
-------
matrix : object
Further scatter plot can be overlaid using the
matrix.add_scatterplot function.
"""

N = dataframe.N

matrix = setup_scatter_matrix(N=N, var_names=dataframe.var_names, **setup_args)
matrix.add_scatterplot(dataframe=dataframe, **add_scatterplot_args)

if name is not None:
matrix.savefig(name=name)

return matrix


class setup_scatter_matrix:
"""Create matrix of scatter plot panels.
Class to setup figure object. The function add_scatterplot allows to plot
scatterplots of variables in the dataframe. Multiple scatter plots can be
overlaid for comparison.
Parameters
----------
N : int
Number of variables
var_names : list, optional (default: None)
List of variable names. If None, range(N) is used.
figsize : tuple of floats, optional (default: None)
Figure size if new figure is created. If None, default pyplot figsize
is used.
label_space_left : float, optional (default: 0.1)
Fraction of horizontal figure space to allocate left of plot for labels.
label_space_top : float, optional (default: 0.05)
Fraction of vertical figure space to allocate top of plot for labels.
legend_width : float, optional (default: 0.15)
Fraction of horizontal figure space to allocate right of plot for
legend.
plot_gridlines : bool, optional (default: False)
Whether to show a grid.
label_fontsize : int, optional (default: 10)
Fontsize of variable labels.
"""

def __init__(
self,
N,
var_names=None,
figsize=None,
label_space_left=0.1,
label_space_top=0.05,
legend_width=0.15,
legend_fontsize=10,
plot_gridlines=False,
label_fontsize=10,
):

self.labels = []

self.legend_width = legend_width
self.legend_fontsize = legend_fontsize

self.label_space_left = label_space_left
self.label_space_top = label_space_top
self.label_fontsize = label_fontsize

self.fig = pyplot.figure(figsize=figsize)

self.axes_dict = {}

if var_names is None:
var_names = range(N)

plot_index = 1
for i in range(N):
for j in range(N):
self.axes_dict[(i, j)] = self.fig.add_subplot(N, N, plot_index, axes_class=Axes)
# Plot process labels
if j == 0:
trans = transforms.blended_transform_factory(
self.fig.transFigure, self.axes_dict[(i, j)].transAxes
)
self.axes_dict[(i, j)].text(
0.01,
0.5,
"%s" % str(var_names[i]),
fontsize=label_fontsize,
horizontalalignment="left",
verticalalignment="center",
transform=trans,
)
if i == 0:
trans = transforms.blended_transform_factory(
self.axes_dict[(i, j)].transAxes, self.fig.transFigure
)
self.axes_dict[(i, j)].text(
0.5,
0.99,
r"${\to}$ " + "%s" % str(var_names[j]),
fontsize=label_fontsize,
horizontalalignment="center",
verticalalignment="top",
transform=trans,
)

self.axes_dict[(i, j)].axis["right"].set_visible(False)
self.axes_dict[(i, j)].axis["top"].set_visible(False)

if j != 0:
self.axes_dict[(i, j)].get_yaxis().set_ticklabels([])
if i != dataframe.N - 1:
self.axes_dict[(i, j)].get_xaxis().set_ticklabels([])

if plot_gridlines:
self.axes_dict[(i, j)].grid(
True,
which="major",
color="black",
linestyle="dotted",
dashes=(1, 1),
linewidth=0.05,
zorder=-5,
)

plot_index += 1

def add_scatterplot(
self,
dataframe,
scatter_lags=None,
color="black",
label=None,
marker=".",
markersize=5,
alpha=1.0,
):
"""Add lag function plot from val_matrix array.
Parameters
----------
dataframe : data object
Tigramite dataframe object. It must have the attributes dataframe.values
yielding a numpy array of shape (observations T, variables N) and
optionally a mask of the same shape and a missing values flag.
scatter_lags : array
Lags to use in scatter plots. Either None or of shape (N, N). Then the
entry scatter_lags[i, j] = tau will depict the scatter plot of
time series (i, -tau) vs (j, 0). If None, tau = 0 for i != j and for i = j
tau = 1.
color : str, optional (default: 'black')
Line color.
label : str
Test statistic label.
marker : matplotlib marker symbol, optional (default: '.')
Marker.
markersize : int, optional (default: 5)
Marker size.
alpha : float, optional (default: 1.)
Opacity.
"""

if label is not None:
self.labels.append((label, color, marker, markersize, alpha))

for ij in list(self.axes_dict):
i = ij[0]
j = ij[1]
if scatter_lags is None:
if i == j:
lag = 1
else:
lag = 0
else:
lag = scatter_lags[i,j]
if lag == 0:
x = np.copy(dataframe.values[:, i])
y = np.copy(dataframe.values[:, j])
else:
x = np.copy(dataframe.values[:-lag, i])
y = np.copy(dataframe.values[lag:, j])
if dataframe.mask is not None:
x[dataframe.mask[:-lag, i]] = np.nan
y[dataframe.mask[lag:, j]] = np.nan
# print(i, j, lag, x.shape, y.shape)
self.axes_dict[(i, j)].scatter(
x, y,
color=color,
marker=marker,
s=markersize,
alpha=alpha,
clip_on=False,
label=r"$\tau{=}%d$" %lag,
)
# self.axes_dict[(i, j)].text(0., 1., r"$\tau{=}%d$" %lag,
# fontsize=self.legend_fontsize,
# ha='left', va='top',
# transform=self.axes_dict[(i, j)].transAxes)


def savefig(self, name=None):
"""Save matrix figure.
Parameters
----------
name : str, optional (default: None)
File name. If None, figure is shown in window.
"""

# Trick to plot legends
colors = []
for item in self.labels:
colors.append(item[1])
for ij in list(self.axes_dict):
i = ij[0]
j = ij[1]

leg = self.axes_dict[(i, j)].legend(
# loc="upper left",
ncol=1,
# bbox_to_anchor=(1.05, 0.0, 0.1, 1.0),
# borderaxespad=0,
fontsize=self.legend_fontsize-2,
labelcolor=colors,
).draw_frame(False)

if len(self.labels) > 0:
axlegend = self.fig.add_subplot(111, frameon=False)
axlegend.spines["left"].set_color("none")
axlegend.spines["right"].set_color("none")
axlegend.spines["bottom"].set_color("none")
axlegend.spines["top"].set_color("none")
axlegend.set_xticks([])
axlegend.set_yticks([])

# self.labels.append((label, color, marker, markersize, alpha))
for item in self.labels:
label = item[0]
color = item[1]
marker = item[2]
markersize = item[3]
alpha = item[4]

axlegend.plot(
[],
[],
linestyle="",
color=color,
marker=marker,
markersize=markersize,
label=label,
alpha=alpha,
)
axlegend.legend(
loc="upper left",
ncol=1,
bbox_to_anchor=(1.05, 0.0, 0.1, 1.0),
borderaxespad=0,
fontsize=self.legend_fontsize,
).draw_frame(False)

self.fig.subplots_adjust(
bottom=0.05,
left=self.label_space_left,
right=1.0 - self.legend_width,
top=1.0 - self.label_space_top,
hspace=0.5,
wspace=0.35,
)

else:
self.fig.subplots_adjust(
left=self.label_space_left,
bottom=0.05,
right=0.95,
top=1.0 - self.label_space_top,
hspace=0.35,
wspace=0.35,
)


if name is not None:
self.fig.savefig(name)
else:
pyplot.show()


def _draw_network_with_curved_edges(
fig,
ax,
Expand Down Expand Up @@ -3333,6 +3641,46 @@ def _links_to_tsg(link_coeffs, max_lag=None):

if __name__ == "__main__":

import sys
matplotlib.rc('xtick', labelsize=6)
matplotlib.rc('ytick', labelsize=6)

# Consider some toy data
import tigramite
import tigramite.toymodels.structural_causal_processes as toys
import tigramite.data_processing as pp

T = 1000
def lin_f(x): return x
auto_coeff = 0.3
coeff = 1.
links = {
0: [((0, -1), auto_coeff, lin_f)],
1: [((1, -1), auto_coeff, lin_f), ((0, 0), coeff, lin_f)],
2: [((2, -1), auto_coeff, lin_f), ((1, 0), coeff, lin_f)],
}
data, nonstat = toys.structural_causal_process(links, T=T,
noises=None, seed=7)

dataframe = pp.DataFrame(data, var_names=range(len(links)))
plot_scatterplots(dataframe, name='scattertest.pdf')

# matrix = setup_scatter_matrix(N=dataframe.N,
# var_names=dataframe.var_names)
# scatter_lags = np.ones((3, 3)).astype('int')
# matrix.add_scatterplot(dataframe=dataframe, scatter_lags=scatter_lags,
# label='ones', alpha=0.4)
# scatter_lags = 2*np.ones((3, 3)).astype('int')
# matrix.add_scatterplot(dataframe=dataframe, scatter_lags=scatter_lags,
# label='twos', color='red', alpha=0.4)

# matrix.savefig(name='scattertest.pdf')


# pyplot.show()
sys.exit(0)


val_matrix = np.zeros((4, 4, 3))

# Complete test case
Expand Down
Loading

0 comments on commit 1ff2446

Please sign in to comment.