Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added plot_scatterplots #203

Merged
merged 1 commit into from
Apr 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def run(self):
# Run the setup
setup(
name="tigramite",
version="5.0.1.8",
version="5.0.1.9",
packages=["tigramite", "tigramite.independence_tests", "tigramite.toymodels"],
license="GNU General Public License v3.0",
description="Tigramite causal discovery for time series",
Expand Down
350 changes: 349 additions & 1 deletion tigramite/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from copy import deepcopy
import matplotlib.path as mpath
import matplotlib.patheffects as PathEffects

from mpl_toolkits.axisartist.axislines import Axes
# TODO: Add proper docstrings to internal functions...


Expand Down Expand Up @@ -810,6 +810,314 @@ def savefig(self, name=None):
pyplot.show()



def plot_scatterplots(dataframe, name=None, setup_args=None, add_scatterplot_args=None):
    """Wrapper helper function to plot scatter plots.

    Sets up the matrix object and plots the scatter plots, see parameters in
    setup_scatter_matrix and add_scatterplot.

    Parameters
    ----------
    dataframe : data object
        Tigramite dataframe object. Its attributes N and var_names are used to
        set up the panel matrix. It must have the attributes dataframe.values
        yielding a numpy array of shape (observations T, variables N) and
        optionally a mask of the same shape and a missing values flag.
    name : str, optional (default: None)
        File name. If given, the figure is saved under this name. If None,
        the figure is neither saved nor shown by this function; call
        savefig on the returned matrix object to display it.
    setup_args : dict, optional (default: None)
        Arguments for setting up the scatter plot matrix, see doc of
        setup_scatter_matrix. None is treated as an empty dict.
    add_scatterplot_args : dict, optional (default: None)
        Arguments for adding a scatter plot matrix, see doc of
        setup_scatter_matrix.add_scatterplot. None is treated as an empty
        dict.

    Returns
    -------
    matrix : object
        Further scatter plot can be overlaid using the
        matrix.add_scatterplot function.
    """

    # None sentinels instead of mutable ``{}`` defaults, which would be shared
    # between all calls of this function.
    if setup_args is None:
        setup_args = {}
    if add_scatterplot_args is None:
        add_scatterplot_args = {}

    matrix = setup_scatter_matrix(
        N=dataframe.N, var_names=dataframe.var_names, **setup_args
    )
    matrix.add_scatterplot(dataframe=dataframe, **add_scatterplot_args)

    if name is not None:
        matrix.savefig(name=name)

    return matrix


class setup_scatter_matrix:
    """Create matrix of scatter plot panels.

    Class to setup figure object. The function add_scatterplot allows to plot
    scatterplots of variables in the dataframe. Multiple scatter plots can be
    overlaid for comparison.

    Parameters
    ----------
    N : int
        Number of variables
    var_names : list, optional (default: None)
        List of variable names. If None, range(N) is used.
    figsize : tuple of floats, optional (default: None)
        Figure size if new figure is created. If None, default pyplot figsize
        is used.
    label_space_left : float, optional (default: 0.1)
        Fraction of horizontal figure space to allocate left of plot for labels.
    label_space_top : float, optional (default: 0.05)
        Fraction of vertical figure space to allocate top of plot for labels.
    legend_width : float, optional (default: 0.15)
        Fraction of horizontal figure space to allocate right of plot for
        legend.
    legend_fontsize : int, optional (default: 10)
        Fontsize of the figure legend. The per-panel lag legends use this
        value minus 2.
    plot_gridlines : bool, optional (default: False)
        Whether to show a grid.
    label_fontsize : int, optional (default: 10)
        Fontsize of variable labels.
    """

    def __init__(
        self,
        N,
        var_names=None,
        figsize=None,
        label_space_left=0.1,
        label_space_top=0.05,
        legend_width=0.15,
        legend_fontsize=10,
        plot_gridlines=False,
        label_fontsize=10,
    ):

        # One (label, color, marker, markersize, alpha) tuple per labelled
        # call of add_scatterplot; consumed by savefig to build the legend.
        self.labels = []

        self.legend_width = legend_width
        self.legend_fontsize = legend_fontsize

        self.label_space_left = label_space_left
        self.label_space_top = label_space_top
        self.label_fontsize = label_fontsize

        self.fig = pyplot.figure(figsize=figsize)

        # Maps (row i, column j) to the axes of that panel.
        self.axes_dict = {}

        if var_names is None:
            var_names = range(N)

        plot_index = 1
        for i in range(N):
            for j in range(N):
                ax = self.fig.add_subplot(N, N, plot_index, axes_class=Axes)
                self.axes_dict[(i, j)] = ax

                # Plot process labels
                if j == 0:
                    # Row label: x in figure coordinates, y in axes
                    # coordinates of this panel.
                    trans = transforms.blended_transform_factory(
                        self.fig.transFigure, ax.transAxes
                    )
                    ax.text(
                        0.01,
                        0.5,
                        "%s" % str(var_names[i]),
                        fontsize=label_fontsize,
                        horizontalalignment="left",
                        verticalalignment="center",
                        transform=trans,
                    )
                if i == 0:
                    # Column label: x in axes coordinates of this panel, y in
                    # figure coordinates.
                    trans = transforms.blended_transform_factory(
                        ax.transAxes, self.fig.transFigure
                    )
                    ax.text(
                        0.5,
                        0.99,
                        r"${\to}$ " + "%s" % str(var_names[j]),
                        fontsize=label_fontsize,
                        horizontalalignment="center",
                        verticalalignment="top",
                        transform=trans,
                    )

                ax.axis["right"].set_visible(False)
                ax.axis["top"].set_visible(False)

                # Tick labels are only kept on the first column (y) and the
                # last row (x). Use the constructor argument N here; no
                # dataframe is available inside __init__.
                if j != 0:
                    ax.get_yaxis().set_ticklabels([])
                if i != N - 1:
                    ax.get_xaxis().set_ticklabels([])

                if plot_gridlines:
                    ax.grid(
                        True,
                        which="major",
                        color="black",
                        linestyle="dotted",
                        dashes=(1, 1),
                        linewidth=0.05,
                        zorder=-5,
                    )

                plot_index += 1

    def add_scatterplot(
        self,
        dataframe,
        scatter_lags=None,
        color="black",
        label=None,
        marker=".",
        markersize=5,
        alpha=1.0,
    ):
        """Add scatter plots of the dataframe's variables to all panels.

        Panel (i, j) shows variable i, shifted back in time by the lag of that
        panel, on the x-axis against variable j on the y-axis.

        Parameters
        ----------
        dataframe : data object
            Tigramite dataframe object. It must have the attributes dataframe.values
            yielding a numpy array of shape (observations T, variables N) and
            a mask attribute that is either None or an array of the same
            shape. Masked entries are replaced by NaN in the plotted copies;
            the dataframe itself is not modified.
        scatter_lags : array
            Lags to use in scatter plots. Either None or of shape (N, N). Then the
            entry scatter_lags[i, j] = tau will depict the scatter plot of
            time series (i, -tau) vs (j, 0). If None, tau = 0 for i != j and for i = j
            tau = 1.
        color : str, optional (default: 'black')
            Marker color.
        label : str
            Label of this data set in the figure legend. If None, no legend
            entry is added.
        marker : matplotlib marker symbol, optional (default: '.')
            Marker.
        markersize : int, optional (default: 5)
            Marker size.
        alpha : float, optional (default: 1.)
            Opacity.
        """

        if label is not None:
            self.labels.append((label, color, marker, markersize, alpha))

        for (i, j), ax in self.axes_dict.items():
            if scatter_lags is None:
                lag = 1 if i == j else 0
            else:
                lag = scatter_lags[i, j]

            # Row selections pairing x at time t - lag with y at time t. The
            # lag == 0 case needs full slices because ``[:-0]`` would be empty.
            if lag == 0:
                x_rows = slice(None)
                y_rows = slice(None)
            else:
                x_rows = slice(None, -lag)
                y_rows = slice(lag, None)

            x = np.copy(dataframe.values[x_rows, i])
            y = np.copy(dataframe.values[y_rows, j])

            # Apply the mask with the same row selection as the values so the
            # boolean index always matches the length of x and y.
            if dataframe.mask is not None:
                x[dataframe.mask[x_rows, i]] = np.nan
                y[dataframe.mask[y_rows, j]] = np.nan

            ax.scatter(
                x, y,
                color=color,
                marker=marker,
                s=markersize,
                alpha=alpha,
                clip_on=False,
                label=r"$\tau{=}%d$" % lag,
            )

    def savefig(self, name=None):
        """Save matrix figure.

        Adds a frameless lag legend to every panel and, if labelled scatter
        plots were added, a figure legend to the right of the panel matrix.

        Parameters
        ----------
        name : str, optional (default: None)
            File name. If None, figure is shown in window.
        """

        # Colors of all labelled data sets, used to color the per-panel
        # legend texts.
        colors = [item[1] for item in self.labels]

        for ax in self.axes_dict.values():
            ax.legend(
                ncol=1,
                fontsize=self.legend_fontsize - 2,
                labelcolor=colors,
            ).draw_frame(False)

        if len(self.labels) > 0:
            # Trick to plot legends: an invisible axes spanning the whole
            # figure carries empty line plots whose only purpose is to
            # provide the legend handles.
            axlegend = self.fig.add_subplot(111, frameon=False)
            for side in ("left", "right", "bottom", "top"):
                axlegend.spines[side].set_color("none")
            axlegend.set_xticks([])
            axlegend.set_yticks([])

            for label, color, marker, markersize, alpha in self.labels:
                axlegend.plot(
                    [],
                    [],
                    linestyle="",
                    color=color,
                    marker=marker,
                    markersize=markersize,
                    label=label,
                    alpha=alpha,
                )
            axlegend.legend(
                loc="upper left",
                ncol=1,
                bbox_to_anchor=(1.05, 0.0, 0.1, 1.0),
                borderaxespad=0,
                fontsize=self.legend_fontsize,
            ).draw_frame(False)

            # Leave room on the right for the figure legend.
            self.fig.subplots_adjust(
                bottom=0.05,
                left=self.label_space_left,
                right=1.0 - self.legend_width,
                top=1.0 - self.label_space_top,
                hspace=0.5,
                wspace=0.35,
            )

        else:
            self.fig.subplots_adjust(
                left=self.label_space_left,
                bottom=0.05,
                right=0.95,
                top=1.0 - self.label_space_top,
                hspace=0.35,
                wspace=0.35,
            )

        if name is not None:
            self.fig.savefig(name)
        else:
            pyplot.show()


def _draw_network_with_curved_edges(
fig,
ax,
Expand Down Expand Up @@ -3333,6 +3641,46 @@ def _links_to_tsg(link_coeffs, max_lag=None):

if __name__ == "__main__":

import sys
matplotlib.rc('xtick', labelsize=6)
matplotlib.rc('ytick', labelsize=6)

# Consider some toy data
import tigramite
import tigramite.toymodels.structural_causal_processes as toys
import tigramite.data_processing as pp

T = 1000
def lin_f(x): return x
auto_coeff = 0.3
coeff = 1.
links = {
0: [((0, -1), auto_coeff, lin_f)],
1: [((1, -1), auto_coeff, lin_f), ((0, 0), coeff, lin_f)],
2: [((2, -1), auto_coeff, lin_f), ((1, 0), coeff, lin_f)],
}
data, nonstat = toys.structural_causal_process(links, T=T,
noises=None, seed=7)

dataframe = pp.DataFrame(data, var_names=range(len(links)))
plot_scatterplots(dataframe, name='scattertest.pdf')

# matrix = setup_scatter_matrix(N=dataframe.N,
# var_names=dataframe.var_names)
# scatter_lags = np.ones((3, 3)).astype('int')
# matrix.add_scatterplot(dataframe=dataframe, scatter_lags=scatter_lags,
# label='ones', alpha=0.4)
# scatter_lags = 2*np.ones((3, 3)).astype('int')
# matrix.add_scatterplot(dataframe=dataframe, scatter_lags=scatter_lags,
# label='twos', color='red', alpha=0.4)

# matrix.savefig(name='scattertest.pdf')


# pyplot.show()
sys.exit(0)


val_matrix = np.zeros((4, 4, 3))

# Complete test case
Expand Down
Loading