From c4ffd5749a3f072bc80f7ac56afd0a25a4b124ac Mon Sep 17 00:00:00 2001 From: CommonClimate Date: Tue, 18 Jun 2024 10:56:32 -0700 Subject: [PATCH 1/5] Update tsmodel.py --- pyleoclim/utils/tsmodel.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pyleoclim/utils/tsmodel.py b/pyleoclim/utils/tsmodel.py index 529cc1cc..0f3b48c1 100644 --- a/pyleoclim/utils/tsmodel.py +++ b/pyleoclim/utils/tsmodel.py @@ -726,7 +726,7 @@ def uar1_sim(t, tau, sigma_2=1): def inverse_cumsum(arr): return np.diff(np.concatenate(([0], arr))) -def random_time_axis(n, delta_t_dist = "exponential", param = [1.0]): +def random_time_axis(n, delta_t_dist = "exponential", param = [1.0], seed = None): ''' Generate a random time axis according to a specific probability model @@ -756,12 +756,14 @@ def random_time_axis(n, delta_t_dist = "exponential", param = [1.0]): ''' + if seed is not None: + np.random.seed(seed) + # check for a valid distribution valid_distributions = ["exponential", "poisson", "pareto", "random_choice"] if delta_t_dist not in valid_distributions: raise ValueError("delta_t_dist must be one of: 'exponential', 'poisson', 'pareto', 'random_choice'.") - param = np.array(param) # coerce array type if delta_t_dist == "exponential": From eb3f4e1f3acdb94912f097aefe0881da4b5f7253 Mon Sep 17 00:00:00 2001 From: CommonClimate Date: Tue, 18 Jun 2024 12:09:15 -0700 Subject: [PATCH 2/5] replaced freq_mehod by freq in Series.wavelet() and MultipleSeries.wavelet(). Also cleaned up module imports in core classes --- pyleoclim/core/coherence.py | 2 +- pyleoclim/core/multiplegeoseries.py | 9 ++++----- pyleoclim/core/multipleseries.py | 4 ++-- pyleoclim/core/multivardecomp.py | 3 --- pyleoclim/core/resolutions.py | 6 ++---- pyleoclim/core/scalograms.py | 6 ++---- pyleoclim/core/series.py | 29 ++++++++++++++++++--------- pyleoclim/core/ssares.py | 2 -- pyleoclim/tests/test_core_Series.py | 31 ++++++++++++++++++++--------- setup.py | 2 +- 10 files changed, 54 insertions(+), 40 deletions(-) diff --git a/pyleoclim/core/coherence.py b/pyleoclim/core/coherence.py index b851f51a..1c6838db 100644 --- a/pyleoclim/core/coherence.py +++ b/pyleoclim/core/coherence.py @@ -13,7 +13,7 @@ from copy import deepcopy from matplotlib.ticker import ScalarFormatter, FormatStrFormatter -from matplotlib import cm +#from matplotlib import cm from matplotlib import gridspec from tqdm import tqdm diff --git a/pyleoclim/core/multiplegeoseries.py b/pyleoclim/core/multiplegeoseries.py index b7461a42..c1db9ae3 100644 --- a/pyleoclim/core/multiplegeoseries.py +++ b/pyleoclim/core/multiplegeoseries.py @@ -7,14 +7,13 @@ from ..core.multipleseries import MultipleSeries from ..utils import mapping as mp from ..utils import plotting -import warnings -import copy + import matplotlib.pyplot as plt import matplotlib as mpl -from matplotlib import cm -from itertools import cycle -import matplotlib.lines as mlines +#from matplotlib import cm +#from itertools import cycle +#import matplotlib.lines as mlines import numpy as np #import warnings diff --git a/pyleoclim/core/multipleseries.py b/pyleoclim/core/multipleseries.py index 2cfab127..9dcdc65b 100644 --- a/pyleoclim/core/multipleseries.py +++ b/pyleoclim/core/multipleseries.py @@ -1323,7 +1323,7 @@ def spectral(self, method='lomb_scargle', freq=None, settings=None, mute_pbar=Fa return psds - def wavelet(self, method='cwt', settings={}, freq_method='log', freq_kwargs=None, verbose=False, mute_pbar=False): + def wavelet(self, method='cwt', settings={}, freq=None, freq_kwargs=None, verbose=False, mute_pbar=False): '''Wavelet analysis Parameters @@ -1403,7 +1403,7 @@ def wavelet(self, method='cwt', settings={}, freq_method='log', freq_kwargs=None scal_list = [] for s in tqdm(self.series_list, desc='Performing wavelet analysis on individual series', position=0, leave=True, disable=mute_pbar): - scal_tmp = s.wavelet(method=method, settings=settings, freq_method=freq_method, freq_kwargs=freq_kwargs, verbose=verbose) + scal_tmp = s.wavelet(method=method, settings=settings, freq=freq, freq_kwargs=freq_kwargs, verbose=verbose) scal_list.append(scal_tmp) scals = MultipleScalogram(scalogram_list=scal_list) diff --git a/pyleoclim/core/multivardecomp.py b/pyleoclim/core/multivardecomp.py index 861b45b3..0dc7bc39 100644 --- a/pyleoclim/core/multivardecomp.py +++ b/pyleoclim/core/multivardecomp.py @@ -1,10 +1,7 @@ import numpy as np #import pandas as pd from matplotlib import pyplot as plt, gridspec -#from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable from matplotlib.ticker import MaxNLocator -#import cartopy.crs as ccrs -#import cartopy.feature as cfeature from ..core import series from ..utils import plotting, mapping, tsbase diff --git a/pyleoclim/core/resolutions.py b/pyleoclim/core/resolutions.py index d29b4c50..e6a05423 100644 --- a/pyleoclim/core/resolutions.py +++ b/pyleoclim/core/resolutions.py @@ -4,10 +4,8 @@ Resolution objects are designed to contain, display, and analyze information on the resolution of the time axis of a Series object. """ -from ..utils import tsutils, plotting, tsmodel, tsbase - -import warnings - +from ..utils import plotting +#import warnings import numpy as np import seaborn as sns import pandas as pd diff --git a/pyleoclim/core/scalograms.py b/pyleoclim/core/scalograms.py index a6ad21f2..0af36346 100644 --- a/pyleoclim/core/scalograms.py +++ b/pyleoclim/core/scalograms.py @@ -1,7 +1,7 @@ # It is unclear why the documentation for these two modules does not build automatically using automodule. It therefore had to be built using autoclass -from ..utils import plotting, lipdutils, tsutils +from ..utils import plotting, tsutils from ..utils import wavelet as waveutils import matplotlib.pyplot as plt @@ -9,12 +9,10 @@ from tabulate import tabulate from copy import deepcopy -from matplotlib.ticker import ScalarFormatter, FormatStrFormatter #, MaxNLocator -from mpl_toolkits.axes_grid1.inset_locator import inset_axes +from matplotlib.ticker import ScalarFormatter, FormatStrFormatter from scipy.stats.mstats import mquantiles -#from ..core import MultipleScalogram class Scalogram: ''' diff --git a/pyleoclim/core/series.py b/pyleoclim/core/series.py index 29c1c33b..6be4fc78 100644 --- a/pyleoclim/core/series.py +++ b/pyleoclim/core/series.py @@ -10,7 +10,7 @@ import datetime as dt import re -from ..utils import tsutils, plotting, tsmodel, tsbase, lipdutils, jsonutils +from ..utils import tsutils, plotting, tsbase, jsonutils from ..utils import wavelet as waveutils from ..utils import spectral as specutils from ..utils import correlation as corrutils @@ -1847,10 +1847,9 @@ def summary_plot(self, psd, scalogram, figsize=[8, 10], title=None, .. jupyter-execute:: - import pyleoclim as pyleo series = pyleo.utils.load_dataset('SOI') psd = series.spectral(freq = 'welch') - scalogram = series.wavelet(freq_method = 'welch') + scalogram = series.wavelet(freq = 'welch') fig, ax = series.summary_plot(psd = psd,scalogram = scalogram) @@ -1859,7 +1858,6 @@ def summary_plot(self, psd, scalogram, figsize=[8, 10], title=None, .. jupyter-execute:: - import pyleoclim as pyleo series = pyleo.utils.load_dataset('SOI') psd = series.spectral(freq = 'welch').signif_test(number=20) scalogram = series.wavelet(freq_method = 'welch') @@ -3066,7 +3064,7 @@ def spectral(self, method='lomb_scargle', freq=None, freq_kwargs=None, settings= return psd - def wavelet(self, method='cwt', settings=None, freq_method='log', freq_kwargs=None, verbose=False): + def wavelet(self, method='cwt', settings=None, freq=None, freq_kwargs=None, verbose=False): ''' Perform wavelet analysis on a timeseries Parameters @@ -3177,12 +3175,25 @@ def wavelet(self, method='cwt', settings=None, freq_method='log', freq_kwargs=No settings = {} if settings is None else settings.copy() freq_kwargs = {} if freq_kwargs is None else freq_kwargs.copy() - freq = specutils.make_freq_vector(self.time, method=freq_method, **freq_kwargs) - + if 'freq' in settings.keys(): + freq_vec = settings['freq'] + freq_method = "user_specified" + else: + if freq is None: # assign the frequency method automatically based on context + freq_vec = specutils.make_freq_vector(self.time, method='log', **freq_kwargs) + freq_method = "log" + elif isinstance(freq, str): # apply the specified method + freq_vec = specutils.make_freq_vector(self.time, method=freq, **freq_kwargs) + freq_method = freq + elif isinstance(freq,np.ndarray): # use the specified vector if dimensions check out + freq_vec = np.squeeze(freq) + freq_method = "user_specified" + if freq.ndim != 1: + raise ValueError("freq should be a 1-dimensional array") args = {} - args['wwz'] = {'freq': freq} - args['cwt'] = {'freq': freq} + args['wwz'] = {'freq': freq_vec} + args['cwt'] = {'freq': freq_vec} if method == 'wwz': if 'ntau' in settings.keys(): diff --git a/pyleoclim/core/ssares.py b/pyleoclim/core/ssares.py index b9d9cf82..37adf652 100644 --- a/pyleoclim/core/ssares.py +++ b/pyleoclim/core/ssares.py @@ -10,8 +10,6 @@ import seaborn as sns from matplotlib import pyplot as plt, gridspec from matplotlib.ticker import MaxNLocator - -from ..core import series from ..utils import plotting diff --git a/pyleoclim/tests/test_core_Series.py b/pyleoclim/tests/test_core_Series.py index d94641cf..99bbaf67 100644 --- a/pyleoclim/tests/test_core_Series.py +++ b/pyleoclim/tests/test_core_Series.py @@ -1042,28 +1042,41 @@ def test_wave_t1(self,wave_method): n = 100 ts = gen_ts(model='colored_noise',nt=n) freq = np.linspace(1/n, 1/2, 20) - _ = ts.wavelet(method=wave_method, settings={'freq': freq}) - + scal = ts.wavelet(method=wave_method, settings={'freq': freq}) + assert scal.freq_method == "user_specified" + def test_wave_t2(self): ''' Test Series.wavelet() ntau option and plot functionality ''' - ts = gen_ts(model='colored_noise',nt=200) + ts = gen_ts(model='colored_noise',nt=100) _ = ts.wavelet(method='wwz',settings={'ntau':10}) @pytest.mark.parametrize('mother',['MORLET', 'PAUL', 'DOG']) def test_wave_t3(self,mother): ''' Test Series.wavelet() with different mother wavelets ''' - ts = gen_ts(model='colored_noise',nt=200) - _ = ts.wavelet(method='cwt',settings={'mother':mother}) + ts = gen_ts(model='colored_noise',nt=100) + _ = ts.wavelet(settings={'mother':mother}) @pytest.mark.parametrize('freq_meth', ['log', 'scale', 'nfft', 'welch']) def test_wave_t4(self,freq_meth): - ''' Test Series.wavelet() with different mother wavelets + ''' Test Series.wavelet() with different frequency methods ''' - ts = gen_ts(model='colored_noise',nt=200) - _ = ts.wavelet(method='cwt',freq_method=freq_meth) - + ts = gen_ts(model='colored_noise',nt=100) + scal = ts.wavelet(freq=freq_meth) + assert scal.freq_method == freq_meth + + @pytest.mark.parametrize('freq', [None,np.linspace(1/100,1/2,num=20)]) + def test_wave_t5(self,freq): + ''' Test Series.wavelet() with different frequency vectors + ''' + ts = gen_ts(model='colored_noise',nt=100) + scal = ts.wavelet(freq=freq) + if freq is None: + assert scal.freq_method == 'log' + else: + assert scal.freq_method == 'user_specified' + class TestUISeriesSsa(): ''' Test the SSA functionalities ''' diff --git a/setup.py b/setup.py index cf33da89..3eec46f0 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ from setuptools import setup, find_packages -version = '1.0.0' +version = '1.0.0b0' # Read the readme file contents into variable def read(fname): From 69dc40201426e33cf9f531743a9590b1715e7011 Mon Sep 17 00:00:00 2001 From: CommonClimate Date: Tue, 18 Jun 2024 12:43:52 -0700 Subject: [PATCH 3/5] implemented freq in wavelet_coherence() with CI tests --- pyleoclim/core/multipleseries.py | 8 ++++-- pyleoclim/core/series.py | 43 ++++++++++++++++++++++------- pyleoclim/tests/test_core_Series.py | 20 ++++++++++++-- 3 files changed, 57 insertions(+), 14 deletions(-) diff --git a/pyleoclim/core/multipleseries.py b/pyleoclim/core/multipleseries.py index 9dcdc65b..0435f064 100644 --- a/pyleoclim/core/multipleseries.py +++ b/pyleoclim/core/multipleseries.py @@ -1341,7 +1341,12 @@ def wavelet(self, method='cwt', settings={}, freq=None, freq_kwargs=None, verbos Settings for the particular method. The default is {}. - freq_method : str; {'log', 'scale', 'nfft', 'lomb_scargle', 'welch'} + freq : str or array, optional + Information to produce the frequency vector (highly consequential for the WWZ method) + This can be 'log','scale', 'nfft', 'lomb_scargle', 'welch' or a NumPy array. + If a string, will use `make_freq_vector()` with the specified frequency-generating method. + If an array, this will be passed directly to the spectral method. + If None (default), will use the 'log' method freq_kwargs : dict @@ -1392,7 +1397,6 @@ def wavelet(self, method='cwt', settings={}, freq=None, freq_kwargs=None, verbos .. jupyter-execute:: - import pyleoclim as pyleo soi = pyleo.utils.load_dataset('SOI') nino = pyleo.utils.load_dataset('NINO3') ms = (soi & nino) diff --git a/pyleoclim/core/series.py b/pyleoclim/core/series.py index 6be4fc78..4d5f9951 100644 --- a/pyleoclim/core/series.py +++ b/pyleoclim/core/series.py @@ -3076,8 +3076,12 @@ def wavelet(self, method='cwt', settings=None, freq=None, freq_kwargs=None, verb is appropriate for unevenly-spaced series. Default is cwt, returning an error if the Series is unevenly-spaced. - freq_method : str, optional - Can be one of 'log', 'scale', 'nfft', 'lomb_scargle', 'welch'. + freq : str or array, optional + Information to produce the frequency vector (highly consequential for the WWZ method) + This can be 'log','scale', 'nfft', 'lomb_scargle', 'welch' or a NumPy array. + If a string, will use `make_freq_vector()` with the specified frequency-generating method. + If an array, this will be passed directly to the spectral method. + If None (default), will use the 'log' method freq_kwargs : dict Arguments for the frequency vector @@ -3234,7 +3238,7 @@ def wavelet(self, method='cwt', settings=None, freq=None, freq_kwargs=None, verb return scal def wavelet_coherence(self, target_series, method='cwt', settings=None, - freq_method='log', freq_kwargs=None, verbose=False, + freq=None, freq_kwargs=None, verbose=False, common_time_kwargs=None): ''' Performs wavelet coherence analysis with the target timeseries @@ -3248,8 +3252,12 @@ def wavelet_coherence(self, target_series, method='cwt', settings=None, if the series share the same evenly-spaced time axis. 'wwz' is designed for unevenly-spaced data, but is far slower. - freq_method : str - {'log','scale', 'nfft', 'lomb_scargle', 'welch'} + freq : str or array, optional + Information to produce the frequency vector (highly consequential for the WWZ method) + This can be 'log','scale', 'nfft', 'lomb_scargle', 'welch' or a NumPy array. + If a string, will use `make_freq_vector()` with the specified frequency-generating method. + If an array, this will be passed directly to the spectral method. + If None (default), will use the 'log' method freq_kwargs : dict Arguments for frequency vector @@ -3340,7 +3348,7 @@ def wavelet_coherence(self, target_series, method='cwt', settings=None, # by default, the plot function will look for the closest quantile to 0.95, but it is easy to adjust: cwt_sig.plot(signif_thresh = 0.9) - Another plotting option, `dashboard`, allows to visualize both + Another plotting option, `dashboard()`, allows to visualize both timeseries as well as the wavelet transform coherency (WTC), which quantifies where two timeseries exhibit similar behavior in time-frequency space, and the cross-wavelet transform (XWT), which indicates regions of high common power. @@ -3366,11 +3374,26 @@ def wavelet_coherence(self, target_series, method='cwt', settings=None, # Process options settings = {} if settings is None else settings.copy() freq_kwargs = {} if freq_kwargs is None else freq_kwargs.copy() - freq = specutils.make_freq_vector(self.time, method=freq_method, **freq_kwargs) + + if 'freq' in settings.keys(): + freq_vec = settings['freq'] + freq_method = "user_specified" + else: + if freq is None: # assign the frequency method automatically based on context + freq_vec = specutils.make_freq_vector(self.time, method='log', **freq_kwargs) + freq_method = "log" + elif isinstance(freq, str): # apply the specified method + freq_vec = specutils.make_freq_vector(self.time, method=freq, **freq_kwargs) + freq_method = freq + elif isinstance(freq,np.ndarray): # use the specified vector if dimensions check out + freq_vec = np.squeeze(freq) + freq_method = "user_specified" + if freq.ndim != 1: + raise ValueError("freq should be a 1-dimensional array") + args = {} - args['wwz'] = {'freq': freq, 'verbose': verbose} - args['cwt'] = {'freq': freq} - + args['wwz'] = {'freq': freq_vec, 'verbose': verbose} + args['cwt'] = {'freq': freq_vec} # put on same time axes if necessary if method == 'cwt' and not np.array_equal(self.time, target_series.time): diff --git a/pyleoclim/tests/test_core_Series.py b/pyleoclim/tests/test_core_Series.py index 99bbaf67..aa8b5b47 100644 --- a/pyleoclim/tests/test_core_Series.py +++ b/pyleoclim/tests/test_core_Series.py @@ -974,8 +974,8 @@ def test_xwave_t3(self): v_unevenly = np.delete(ts1.value, deleted_idx) t1_unevenly = np.delete(ts2.time, deleted_idx1) v1_unevenly = np.delete(ts2.value, deleted_idx1) - ts3 = pyleo.Series(time=t_unevenly, value=v_unevenly) - ts4 = pyleo.Series(time=t1_unevenly, value=v1_unevenly) + ts3 = pyleo.Series(time=t_unevenly, value=v_unevenly,auto_time_params=True) + ts4 = pyleo.Series(time=t1_unevenly, value=v1_unevenly,auto_time_params=True) _ = ts3.wavelet_coherence(ts4,method='wwz') def test_xwave_t4(self): @@ -1005,6 +1005,22 @@ def test_xwave_t6(self): ts2 = gen_ts(model='colored_noise') tau = ts1.time[::10] _ = ts1.wavelet_coherence(ts2,method='wwz',settings={'tau':tau}) + + @pytest.mark.parametrize('freq', [None,np.linspace(1/100,1/2,num=20),'log', 'nfft', 'welch']) + def test_xwave_t7(self, freq): + ''' Test Series.wavelet_coherence() with freq method argument + ''' + ts1 = gen_ts(model='colored_noise') + ts2 = gen_ts(model='colored_noise') + + coh = ts1.wavelet_coherence(ts2,freq=freq) + + if freq is None: + assert coh.freq_method == 'log' + elif isinstance(freq,np.ndarray): + assert coh.freq_method == 'user_specified' + elif isinstance(freq, str): + assert coh.freq_method == freq class TestUISeriesGlobalCoherence(): '''Test global coherence From 41358727aa4a8a558c23ae559546e3828e5f117b Mon Sep 17 00:00:00 2001 From: CommonClimate Date: Tue, 18 Jun 2024 12:53:07 -0700 Subject: [PATCH 4/5] implemented CI test for global_coherence() freq passing --- pyleoclim/tests/test_core_Series.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/pyleoclim/tests/test_core_Series.py b/pyleoclim/tests/test_core_Series.py index aa8b5b47..b35467a4 100644 --- a/pyleoclim/tests/test_core_Series.py +++ b/pyleoclim/tests/test_core_Series.py @@ -1032,13 +1032,31 @@ def test_globalcoherence_t0(self): ts2 = gen_ts(model='colored_noise') _ = ts1.global_coherence(ts2) - def test_globalcoherence_t0(self): + def test_globalcoherence_t1(self): ''' Test Series.global_coherence() with passed coh ''' ts1 = gen_ts(model='colored_noise') ts2 = gen_ts(model='colored_noise') coh = ts1.wavelet_coherence(ts2) _ = ts1.global_coherence(coh=coh) + + @pytest.mark.parametrize('freq', [None,np.linspace(1/100,1/2,num=20),'log', 'nfft', 'welch']) + def test_globalcoherence_t2(self,freq): + ''' Test Series.global_coherence() with wavelet kwargs + ''' + ts1 = gen_ts(model='colored_noise') + ts2 = gen_ts(model='colored_noise') + + kwargs = {} + kwargs['freq'] = freq + gcoh = ts1.global_coherence(target_series=ts2,wavelet_kwargs=kwargs) + + if freq is None: + assert gcoh.coh.freq_method == 'log' + elif isinstance(freq,np.ndarray): + assert gcoh.coh.freq_method == 'user_specified' + elif isinstance(freq, str): + assert gcoh.coh.freq_method == freq class TestUISeriesWavelet(): ''' Test the wavelet functionalities From 3e88a693787a64ed5b2e613e29ccff2c884a6d6b Mon Sep 17 00:00:00 2001 From: CommonClimate Date: Tue, 18 Jun 2024 14:22:13 -0700 Subject: [PATCH 5/5] consolidate coherences --- pyleoclim/core/__init__.py | 3 +- pyleoclim/core/coherences.py | 1169 +++++++++++++++++ pyleoclim/core/series.py | 4 +- ...e_Coherence.py => test_core_Coherences.py} | 38 +- pyleoclim/tests/test_core_GeoSeries.py | 1 + pyleoclim/tests/test_core_GlobalCoherence.py | 53 - 6 files changed, 1210 insertions(+), 58 deletions(-) create mode 100644 pyleoclim/core/coherences.py rename pyleoclim/tests/{test_core_Coherence.py => test_core_Coherences.py} (72%) delete mode 100644 pyleoclim/tests/test_core_GlobalCoherence.py diff --git a/pyleoclim/core/__init__.py b/pyleoclim/core/__init__.py index 2c6dd0ed..b697d96b 100644 --- a/pyleoclim/core/__init__.py +++ b/pyleoclim/core/__init__.py @@ -12,8 +12,7 @@ from .surrogateseries import SurrogateSeries from .ensembleseries import EnsembleSeries from .scalograms import Scalogram, MultipleScalogram -from .coherence import Coherence -from .globalcoherence import GlobalCoherence +from .coherences import Coherence, GlobalCoherence from .corr import Corr from .correns import CorrEns from .multivardecomp import MultivariateDecomp diff --git a/pyleoclim/core/coherences.py b/pyleoclim/core/coherences.py new file mode 100644 index 00000000..7f4e74e6 --- /dev/null +++ b/pyleoclim/core/coherences.py @@ -0,0 +1,1169 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +The Coherence class stores the result of Series.wavelet_coherence(), whether WWZ or CWT. +It includes wavelet transform coherency and cross-wavelet transform. +""" +from ..utils import plotting +from ..utils import wavelet as waveutils +from ..core.scalograms import Scalogram, MultipleScalogram + +import matplotlib.pyplot as plt +import numpy as np +from copy import deepcopy + +from matplotlib.ticker import ScalarFormatter, FormatStrFormatter +#from matplotlib import cm +from matplotlib import gridspec + +from tqdm import tqdm +from scipy.stats.mstats import mquantiles +import warnings + +class Coherence: + '''Coherence object, meant to receive the WTC and XWT part of Series.wavelet_coherence() + + See also + -------- + + pyleoclim.core.series.Series.wavelet_coherence : Wavelet coherence method + + ''' + def __init__(self, frequency, scale, time, wtc, xwt, phase, coi=None, + wave_method=None, wave_args=None, + timeseries1=None, timeseries2=None, signif_qs=None, signif_method=None, qs =None, + freq_method=None, freq_kwargs=None, Neff_threshold=3, scale_unit=None, time_label=None): + self.frequency = np.array(frequency) + self.time = np.array(time) + self.scale = np.array(scale) + self.wtc = np.array(wtc) + self.xwt = np.array(xwt) + if coi is not None: + self.coi = np.array(coi) + else: + self.coi = waveutils.make_coi(self.time, Neff_threshold=Neff_threshold) + self.phase = np.array(phase) + self.timeseries1 = timeseries1 + self.timeseries2 = timeseries2 + self.signif_qs = signif_qs + self.signif_method = signif_method + self.freq_method = freq_method + self.freq_kwargs = freq_kwargs + self.wave_method = wave_method + if wave_args is not None: + if 'freq' in wave_args.keys(): + wave_args['freq'] = np.array(wave_args['freq']) + if 'tau' in wave_args.keys(): + wave_args['tau'] = np.array(wave_args['tau']) + self.wave_args = wave_args + self.qs = qs + + if scale_unit is not None: + self.scale_unit = scale_unit + elif timeseries1 is not None: + self.scale_unit = plotting.infer_period_unit_from_time_unit(timeseries1.time_unit) + elif timeseries2 is not None: + self.scale_unit = plotting.infer_period_unit_from_time_unit(timeseries2.time_unit) + else: + self.scale_unit = None + + if time_label is not None: + self.time_label = time_label + elif timeseries1 is not None: + if timeseries1.time_unit is not None: + self.time_label = f'{timeseries1.time_name} [{timeseries1.time_unit}]' + else: + self.time_label = f'{timeseries1.time_name}' + elif timeseries2 is not None: + if timeseries2.time_unit is not None: + self.time_label = f'{timeseries2.time_name} [{timeseries2.time_unit}]' + else: + self.time_label = f'{timeseries2.time_name}' + else: + self.time_label = None + + def copy(self): + '''Copy object + ''' + return deepcopy(self) + + def plot(self, var='wtc', xlabel=None, ylabel=None, title='auto', figsize=[10, 8], + ylim=None, xlim=None, in_scale=True, yticks=None, contourf_style={}, + phase_style={}, cbar_style={}, savefig_settings={}, ax=None, + signif_clr='white', signif_linestyles='-', signif_linewidths=1, + signif_thresh = 0.95, under_clr='ivory', over_clr='black', bad_clr='dimgray'): + '''Plot the cross-wavelet results + + Parameters + ---------- + + var : str {'wtc', 'xwt'} + + variable to be plotted as color field. Default: 'wtc', the wavelet transform coherency. + 'xwt' plots the cross-wavelet transform instead. + + xlabel : str, optional + + x-axis label. The default is None. + + ylabel : str, optional + + y-axis label. The default is None. + + title : str, optional + + Title of the plot. The default is 'auto', where it is made from object metadata. + To mute, pass title = None. + + figsize : list, optional + + Figure size. The default is [10, 8]. + + ylim : list, optional + + y-axis limits. The default is None. + + xlim : list, optional + + x-axis limits. The default is None. + + in_scale : bool, optional + + Plots scales instead of frequencies The default is True. + + yticks : list, optional + + y-ticks label. The default is None. + + contourf_style : dict, optional + + Arguments for the contour plot. The default is {}. + + phase_style : dict, optional + + Arguments for the phase arrows. The default is {}. It includes: + - 'pt': the default threshold above which phase arrows will be plotted + - 'skip_x': the number of points to skip between phase arrows along the x-axis + - 'skip_y': the number of points to skip between phase arrows along the y-axis + - 'scale': number of data units per arrow length unit (see matplotlib.pyplot.quiver) + - 'width': shaft width in arrow units (see matplotlib.pyplot.quiver) + - 'color': arrow color (see matplotlib.pyplot.quiver) + + cbar_style : dict, optional + + Arguments for the color bar. The default is {}. + + savefig_settings : dict, optional + + The default is {}. + the dictionary of arguments for plt.savefig(); some notes below: + - "path" must be specified; it can be any existed or non-existed path, + with or without a suffix; if the suffix is not given in "path", it will follow "format" + - "format" can be one of {"pdf", "eps", "png", "ps"} + + ax : ax, optional + + Matplotlib axis on which to return the figure. The default is None. + + signif_thresh: float in [0, 1] + + Significance threshold. Default is 0.95. If this quantile is not + found in the qs field of the Coherence object, the closest quantile + will be picked. + + signif_clr : str, optional + + Color of the significance line. The default is 'white'. + + signif_linestyles : str, optional + + Style of the significance line. The default is '-'. + + signif_linewidths : float, optional + + Width of the significance line. The default is 1. + + under_clr : str, optional + + Color for under 0. The default is 'ivory'. + + over_clr : str, optional + + Color for over 1. The default is 'black'. + + bad_clr : str, optional + + Color for missing values. The default is 'dimgray'. + + Returns + ------- + fig, ax + + See also + -------- + + pyleoclim.core.coherence.Coherence.dashboard : plots a a dashboard showing the coherence and the cross-wavelet transform. + + pyleoclim.core.series.Series.wavelet_coherence : computes the coherence from two timeseries. + + matplotlib.pyplot.quiver : quiver plot + + Examples + -------- + + Calculate the wavelet coherence of NINO3 and All India Rainfall and plot it: + .. jupyter-execute:: + + + import pyleoclim as pyleo + ts_air = pyleo.utils.load_dataset('AIR') + ts_nino = pyleo.utils.load_dataset('NINO3') + + coh = ts_air.wavelet_coherence(ts_nino) + coh.plot() + + Establish significance against an AR(1) benchmark: + + .. jupyter-execute:: + + coh_sig = coh.signif_test(number=20, qs=[.9,.95,.99]) + coh_sig.plot() + + Note that specifiying 3 significance thresholds does not take any more time as the quantiles are + simply estimated from the same ensemble. By default, the plot function looks + for the closest quantile to 0.95, but this is easy to adjust, e.g. for the 99th percentile: + + .. jupyter-execute:: + + coh_sig.plot(signif_thresh = 0.99) + + By default, the function plots the wavelet transform coherency (WTC), which quantifies where + two timeseries exhibit similar behavior in time-frequency space, regardless of whether this + corresponds to regions of high common power. To visualize the latter, you want to plot the + cross-wavelet transform (XWT) instead, like so: + + .. jupyter-execute:: + + coh_sig.plot(var='xwt') + + ''' + if ax is None: + fig, ax = plt.subplots(figsize=figsize) + + # handling NaNs + mask_freq = [] + for i in range(np.size(self.frequency)): + if all(np.isnan(self.wtc[:, i])): + mask_freq.append(False) + else: + mask_freq.append(True) + + if in_scale: + y_axis = self.scale[mask_freq] + if ylabel is None: + ylabel = f'Scale [{self.scale_unit}]' if self.scale_unit is not None else 'Scale' + + if yticks is None: + yticks_default = np.array([0.1, 0.2, 0.5, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 1e4, 2e4, 5e4, 1e5, 2e5, 5e5, 1e6]) + mask = (yticks_default >= np.min(y_axis)) & (yticks_default <= np.max(y_axis)) + yticks = yticks_default[mask] + else: + y_axis = self.frequency[mask_freq] + if ylabel is None: + ylabel = f'Frequency [1/{self.scale_unit}]' if self.scale_unit is not None else 'Frequency' + + if signif_thresh > 1 or signif_thresh < 0: + raise ValueError("The significance threshold must be in [0, 1] ") + + # plot color field for WTC or XWT + contourf_args = { + 'cmap': 'magma', + 'origin': 'lower', + } + contourf_args.update(contourf_style) + + cmap = plt.get_cmap(contourf_args['cmap']) + cmap.set_under(under_clr) + cmap.set_over(over_clr) + cmap.set_bad(bad_clr) + contourf_args['cmap'] = cmap + + if var == 'wtc': + lev = np.linspace(0, 1, 11) + cont = ax.contourf(self.time, y_axis, self.wtc[:, mask_freq].T, + levels = lev, **contourf_args) + elif var == 'xwt': + cont = ax.contourf(self.time, y_axis, self.xwt[:, mask_freq].T, + levels = 11, **contourf_args) # just pass number of contours + else: + raise ValueError("Unknown variable; please choose either 'wtc' or 'xwt'") + + # plot significance levels + if self.signif_qs is not None: + signif_method_label = { + 'ar1': 'AR(1)', + } + if signif_thresh not in self.qs: + isig = np.abs(np.array(self.qs) - signif_thresh).argmin() + print("Significance threshold {:3.2f} not found in qs. Picking the closest, which is {:3.2f}".format(signif_thresh,self.qs[isig])) + else: + isig = self.qs.index(signif_thresh) + + if var == 'wtc': + signif_coh = self.signif_qs[0].scalogram_list[isig] # extract WTC significance threshold + signif_boundary = self.wtc[:, mask_freq].T / signif_coh.amplitude[:, mask_freq].T + elif var == 'xwt': + signif_coh = self.signif_qs[1].scalogram_list[isig] # extract XWT significance threshold + signif_boundary = self.xwt[:, mask_freq].T / signif_coh.amplitude[:, mask_freq].T + + ax.contour(self.time, y_axis, signif_boundary, [-99, 1], + colors=signif_clr, + linestyles=signif_linestyles, + linewidths=signif_linewidths) + if title is not None: + ax.set_title("Lines:" + str(round(self.qs[isig]*100))+"% threshold") + + # plot colorbar + cbar_args = { + 'label': var.upper(), + 'drawedges': False, + 'orientation': 'vertical', + 'fraction': 0.15, + 'pad': 0.05, + 'ticks': cont.levels + } + cbar_args.update(cbar_style) + + # assign colorbar to axis (instead of fig) : https://matplotlib.org/stable/gallery/subplots_axes_and_figures/colorbar_placement.html + cb = plt.colorbar(cont, ax = ax, **cbar_args) + + # plot cone of influence + ax.set_yscale('log') + ax.plot(self.time, self.coi, 'k--') + + if ylim is None: + ylim = [np.min(y_axis), np.min([np.max(y_axis), np.max(self.coi)])] + + ax.fill_between(self.time, self.coi, np.max(self.coi), color='white', alpha=0.5) + + if yticks is not None: + ax.set_yticks(yticks) + ax.yaxis.set_major_formatter(ScalarFormatter()) + ax.yaxis.set_major_formatter(FormatStrFormatter('%g')) + + if xlabel is None: + xlabel = self.time_label + + if xlabel is not None: + ax.set_xlabel(xlabel) + + if ylabel is not None: + ax.set_ylabel(ylabel) + + # plot phase + skip_x = np.max([int(np.size(self.time)//20), 1]) + skip_y = np.max([int(np.size(y_axis)//20), 1]) + + phase_args = {'pt': 0.5, 'skip_x': skip_x, 'skip_y': skip_y, + 'scale': 30, 'width': 0.004} + phase_args.update(phase_style) + + pt = phase_args['pt'] + skip_x = phase_args['skip_x'] + skip_y = phase_args['skip_y'] + scale = phase_args['scale'] + width = phase_args['width'] + + if 'color' in phase_style: + color = phase_style['color'] + else: + color = 'black' + + phase = np.copy(self.phase)[:, mask_freq] + + if self.signif_qs is None: + if var == 'wtc': + phase[self.wtc[:, mask_freq] < pt] = np.nan + else: + field = self.xwt[:, mask_freq] + phase[field < pt*field.max()] = np.nan + else: + phase[signif_boundary.T < 1] = np.nan + + X, Y = np.meshgrid(self.time, y_axis) + U, V = np.cos(phase).T, np.sin(phase).T + + ax.quiver(X[::skip_y, ::skip_x], Y[::skip_y, ::skip_x], + U[::skip_y, ::skip_x], V[::skip_y, ::skip_x], + scale=scale, width=width, zorder=99, color=color) + + ax.set_ylim(ylim) + + if xlim is not None: + ax.set_xlim(xlim) + + lbl1 = self.timeseries1.label + lbl2 = self.timeseries2.label + + if 'fig' in locals(): + if 'path' in savefig_settings: + plotting.savefig(fig, settings=savefig_settings) + if title is not None and title != 'auto': + fig.suptitle(title) + elif title == 'auto' and lbl1 is not None and lbl1 is not None: + title = 'Wavelet coherency ('+self.wave_method.upper() +') between '+ lbl1 + ' and ' + lbl2 + fig.suptitle(title) + return fig, ax + else: + return ax + + + def dashboard(self, title=None, figsize=[9,12], overlap = True, phase_style = {}, + line_colors = ['tab:blue','tab:orange'], savefig_settings={}, + ts_plot_kwargs = None, wavelet_plot_kwargs= None): + ''' Cross-wavelet dashboard, including the two series, their WTC and XWT. + + Note: this design balances many considerations, and is not easily customizable. + + Parameters + ---------- + title : str, optional + + Title of the plot. The default is None. + + figsize : list, optional + + Figure size. The default is [9, 12], as this is an information-rich figure. + + overlap : boolean, optional + whether to restrict the plot to the period of overlap between the series. Defaults to True + + phase_style : dict, optional + + Arguments for the phase arrows. The default is {}. It includes: + - 'pt': the default threshold above which phase arrows will be plotted + - 'skip_x': the number of points to skip between phase arrows along the x-axis + - 'skip_y': the number of points to skip between phase arrows along the y-axis + - 'scale': number of data units per arrow length unit (see matplotlib.pyplot.quiver) + - 'width': shaft width in arrow units (see matplotlib.pyplot.quiver) + - 'color': arrow color (see matplotlib.pyplot.quiver) + + line_colors : list, optional + + Colors for the 2 traces For nomenclature, see https://matplotlib.org/stable/gallery/color/named_colors.html + + savefig_settings : dict, optional + + The default is {}. + the dictionary of arguments for plt.savefig(); some notes below: + - "path" must be specified; it can be any existed or non-existed path, + with or without a suffix; if the suffix is not given in "path", it will follow "format" + - "format" can be one of {"pdf", "eps", "png", "ps"} + + ts_plot_kwargs : dict + + arguments to be passed to the timeseries subplot, see pyleoclim.core.series.Series.plot for details + + wavelet_plot_kwargs : dict + + arguments to be passed to the contour subplots (XWT and WTC), [see pyleoclim.core.coherence.Coherence.plot for details] + + + Returns + ------- + fig, ax + + See also + -------- + + pyleoclim.core.coherence.Coherence.plot : creates a coherence plot + + pyleoclim.core.series.Series.wavelet_coherence : computes the coherence between two timeseries. + + pyleoclim.core.series.Series.plot: plots a timeseries + + matplotlib.pyplot.quiver: makes a quiver plot + + Examples + -------- + + Calculate the coherence of NINO3 and All India Rainfall and plot it as a dashboard: + + .. jupyter-execute:: + + import pyleoclim as pyleo + ts_air = pyleo.utils.load_dataset('AIR') + ts_nino = pyleo.utils.load_dataset('NINO3') + + coh = ts_air.wavelet_coherence(ts_nino) + coh_sig = coh.signif_test(number=10) + + coh_sig.dashboard() + + You may customize colors like so: + + .. jupyter-execute:: + + coh_sig.dashboard(line_colors=['teal','gold']) + + To export the figure, use `savefig_settings`: + + .. jupyter-execute:: + + coh_sig.dashboard(savefig_settings={'path':'./coh_dash.png','dpi':300}) + + ''' + # prepare options dictionaries + savefig_settings = {} if savefig_settings is None else savefig_settings.copy() + wavelet_plot_kwargs={} if wavelet_plot_kwargs is None else wavelet_plot_kwargs.copy() + ts_plot_kwargs={} if ts_plot_kwargs is None else ts_plot_kwargs.copy() + + + # create figure + fig = plt.figure(figsize=figsize) + gs = gridspec.GridSpec(8, 1) + gs.update(wspace=0, hspace=0.5) # add some breathing room + ax = {} + + # assess period of overlap + xlims = np.min(self.time), np.max(self.time) + + # 1) plot timeseries + #plt.rc('ytick', labelsize=8) + ax['ts1'] = plt.subplot(gs[0:2, 0]) + self.timeseries1.plot(ax=ax['ts1'], color=line_colors[0], **ts_plot_kwargs, legend=False) + ax['ts1'].yaxis.label.set_color(line_colors[0]) + ax['ts1'].tick_params(axis='y', colors=line_colors[0],labelsize=8) + ax['ts1'].spines['left'].set_color(line_colors[0]) + ax['ts1'].spines['bottom'].set_visible(False) + ax['ts1'].grid(False) + ax['ts1'].set_xlabel('') + if overlap: + ax['ts1'].set_xlim(xlims) + + ax['ts2'] = ax['ts1'].twinx() + self.timeseries2.plot(ax=ax['ts2'], color=line_colors[1], **ts_plot_kwargs, legend=False) + ax['ts2'].yaxis.label.set_color(line_colors[1]) + ax['ts2'].tick_params(axis='y', colors=line_colors[1],labelsize=8) + ax['ts2'].spines['right'].set_color(line_colors[1]) + ax['ts2'].spines['right'].set_visible(True) + ax['ts2'].spines['left'].set_visible(False) + ax['ts2'].grid(False) + if overlap: + ax['ts2'].set_xlim(xlims) + + # 2) plot WTC + ax['wtc'] = plt.subplot(gs[2:5, 0], sharex=ax['ts1']) + if 'cbar_style' not in wavelet_plot_kwargs: + wavelet_plot_kwargs.update({'cbar_style':{'orientation': 'horizontal', + 'pad': 0.15, 'aspect': 60}}) + self.plot(var='wtc',ax=ax['wtc'], title= None, **wavelet_plot_kwargs) + #ax['wtc'].xaxis.set_visible(False) # hide x axis + ax['wtc'].set_xlabel('') + + # 3) plot XWT + ax['xwt'] = plt.subplot(gs[5:8, 0], sharex=ax['ts1']) + if 'phase_style' not in wavelet_plot_kwargs: + wavelet_plot_kwargs.update({'phase_style':{'color': 'lightgray'}}) + self.plot(var='xwt',ax=ax['xwt'], title= None, + contourf_style={'cmap': 'viridis'}, + cbar_style={'orientation': 'horizontal','pad': 0.2, 'aspect': 60}, + phase_style=wavelet_plot_kwargs['phase_style']) + + #gs.tight_layout(fig) # this does nothing + + if 'fig' in locals(): + if 'path' in savefig_settings: + plotting.savefig(fig, settings=savefig_settings) + return fig, ax + else: + return ax + + def signif_test(self, number=200, method='ar1sim', seed=None, qs=[0.95], settings=None, mute_pbar=False): + '''Significance testing for Coherence objects + + The method obtains quantiles `qs` of the distribution of coherence between + `number` pairs of Monte Carlo simulations of a process that resembles the original series. + Currently, only AR(1) surrogates are supported. + + Parameters + ---------- + number : int, optional + + Number of surrogate series to create for significance testing. The default is 200. + + method : {'ar1sim','phaseran','CN'}, optional + + Method through which to generate the surrogate series. The default is 'phaseran'. + + seed : int, optional + + Fixes the seed for NumPy's random number generator. + Useful for reproducibility. The default is None, so fresh, unpredictable + entropy will be pulled from the operating system. + + qs : list, optional + + Significance levels to return. The default is [0.95]. + + settings : dict, optional + + Parameters for surrogate model. The default is None. + + mute_pbar : bool, optional + + Mute the progress bar. The default is False. + + Returns + ------- + new : pyleoclim.core.coherence.Coherence + + original Coherence object augmented with significance levels signif_qs, + a list with the following `MultipleScalogram` objects: + * 0: MultipleScalogram for the wavelet transform coherency (WTC) + * 1: MultipleScalogram for the cross-wavelet transform (XWT) + + Each object contains as many Scalogram objects as qs contains values + + See also + -------- + + pyleoclim.core.series.Series.wavelet_coherence : Wavelet coherence + + pyleoclim.core.scalograms.Scalogram : Scalogram object + + pyleoclim.core.scalograms.MultipleScalogram : Multiple Scalogram object + + pyleoclim.core.coherence.Coherence.plot : plotting method for Coherence objects + + Examples + -------- + + Calculate the coherence of NINO3 and All India Rainfall and assess significance: + + .. jupyter-execute:: + + import pyleoclim as pyleo + ts_air = pyleo.utils.load_dataset('AIR') + ts_nino = pyleo.utils.load_dataset('NINO3') + + coh = ts_air.wavelet_coherence(ts_nino) + coh_sig = coh.signif_test(number=20) + coh_sig.plot() + + By default, significance is assessed against a 95% benchmark derived from + an AR(1) process fit to the data, using 200 Monte Carlo simulations. + To customize, one can increase the number of simulations + (more reliable, but slower), and the quantile levels. + + .. jupyter-execute:: + + coh_sig2 = coh.signif_test(number=100, qs=[.9,.95,.99]) + coh_sig2.plot() + + The plot() function will represent the 95% level as contours by default. + If you need to show 99%, say, use the `signif_thresh` argument: + + .. jupyter-execute:: + + coh_sig2.plot(signif_thresh=0.99) + + Note that if the 99% quantile is not present, the plot method will look + for the closest match, but lines are always labeled appropriately. + For reproducibility purposes, it may be good to specify the (pseudo)random number + generator's seed, like so: + + .. jupyter-execute:: + + coh_sig27 = coh.signif_test(number=20, seed=27) + + This will generate exactly the same set of draws from the + (pseudo)random number at every execution, which may be important for marginal features + in small ensembles. In general, however, we recommend increasing the + number of draws to check that features are robust. + + One can also specifiy a different method to obtain surrogates, e.g. phase randomization: + + .. jupyter-execute:: + + coh.signif_test(method='phaseran').plot() + ''' + from ..core.surrogateseries import SurrogateSeries + + if number == 0: + return self + + new = self.copy() + + surr1 = SurrogateSeries(method=method,number=number, seed=seed) + surr1.from_series(self.timeseries1) + surr2 = SurrogateSeries(method=method,number=number, seed=seed) + surr2.from_series(self.timeseries2) + + # adjust time axis + + wtcs, xwts = [], [] + + for i in tqdm(range(number), desc='Performing wavelet coherence on surrogate pairs', total=number, disable=mute_pbar): + coh_tmp = surr1.series_list[i].wavelet_coherence(surr2.series_list[i], + method = self.wave_method, + settings = self.wave_args) + wtcs.append(coh_tmp.wtc) + xwts.append(coh_tmp.xwt) + + wtcs = np.array(wtcs) + xwts = np.array(xwts) + + + ne, nf, nt = np.shape(wtcs) + + # reshape because mquantiles only accepts inputs of at most 2D + wtcs_r = np.reshape(wtcs, (ne, nf*nt)) + xwts_r = np.reshape(xwts, (ne, nf*nt)) + + # define nd-arrays + nq = len(qs) + wtc_qs = np.ndarray(shape=(nq, nf, nt)) + xwt_qs = np.empty_like(wtc_qs) + + # extract quantiles and reshape + wtc_qs = mquantiles(wtcs_r, qs, axis=0) + wtc_qs = np.reshape(wtc_qs, (nq, nf, nt)) + xwt_qs = mquantiles(xwts_r, qs, axis=0) + xwt_qs = np.reshape(xwt_qs, (nq, nf, nt)) + + # put in Scalogram objects for export + wtc_list, xwt_list = [],[] + + for i in range(nq): + wtc_tmp = Scalogram( + frequency=self.frequency, time=self.time, amplitude=wtc_qs[i,:,:], + coi=self.coi, scale = self.scale, + freq_method=self.freq_method, freq_kwargs=self.freq_kwargs, label=f'{qs[i]*100:g}%', + ) + wtc_list.append(wtc_tmp) + xwt_tmp = Scalogram( + frequency=self.frequency, time=self.time, amplitude=xwt_qs[i,:,:], + coi=self.coi, scale = self.scale, + freq_method=self.freq_method, freq_kwargs=self.freq_kwargs, label=f'{qs[i]*100:g}%', + ) + + xwt_list.append(xwt_tmp) + + new.signif_qs = [] + new.signif_qs.append(MultipleScalogram(scalogram_list=wtc_list)) # Export WTC quantiles + new.signif_qs.append(MultipleScalogram(scalogram_list=xwt_list)) # Export XWT quantiles + new.signif_method = method + new.qs = qs + + return new + + def phase_stats(self, scales, number=1000, level=0.05): + ''' Estimate phase angle statistics of a Coherence object + + As per [1], the strength (consistency) of a phase relationship is assessed using: + + * sigma, the circular standard deviation + + * kappa, an estimate of the Von Mises distribution's concentration parameter. + It is a reciprocal measure of dispersion, so 1/kappa is analogous to the variance) [3]. + + Because of inherent persistence of geophysical signals and of the + reproducing kernel of the continuous wavelet transform [3], phase statistics are + assessed relative to an AR(1) model fit to the angle deviations observed at the requested scale(s). + + Specifically, if `number` is specified, the method simulates `number` + Monte Carlo realizations of an AR(1) process fit to fluctuations around + the mean angle. This ensemble is used to obtain the confidence limits: + `sigma_lo` (`level` quantile) and `kappa_hi` (1-`level` quantile). + These correspond to 1-tailed tests of the strength of the relationship. + + Parameters + ---------- + scales : float + + scale at which to evaluate the phase angle + + number : int, optional + + number of AR(1) series to create for significance testing. The default is 1000. + + level : float, optional + + significance level against which to gauge sigma and kappa. default: 0.05 + + + Returns + ------- + result : dict + + contains angle_mean (the mean angle for those scales), sigma (the + circular standard deviation), kappa, sigma_lo (alpha-level quantile + for sigma) and kappa_hi, the (1-alpha)-level quantile for kappa. + + See also + -------- + + pyleoclim.core.series.Series.wavelet_coherence : Wavelet coherence + + pyleoclim.core.scalograms.Scalogram : Scalogram object + + pyleoclim.core.scalograms.MultipleScalogram : Multiple Scalogram object + + pyleoclim.core.coherence.Coherence.plot : plotting method for Coherence objects + + pyleoclime.utils.wavelet.angle_sig : significance of phase angle statistics + + pyleoclim.utils.wavelet.angle_stats: phase angle statistics + + + References + ---------- + + [1] Grinsted, A., J. C. Moore, and S. Jevrejeva (2004), Application of the cross + wavelet transform and wavelet coherence to geophysical time series, + Nonlinear Processes in Geophysics, 11, 561–566. + + [2] Huber, R., Dutra, L. V., & da Costa Freitas, C. (2001). + SAR interferogram phase filtering based on the Von Mises distribution. + In IGARSS 2001. Scanning the Present and Resolving the Future. + Proceedings. IEEE 2001 International Geoscience and Remote Sensing Symposium + (Cat. No. 01CH37217) (Vol. 6, pp. 2816-2818). IEEE. + + [3] Farge, M. and Schneider, K. (2006): Wavelets: application to turbulence + Encyclopedia of Mathematical Physics (Eds. J.-P. Françoise, G. Naber and T.S. Tsun) + pp 408-420. + + Examples + -------- + + Calculate the phase angle between NINO3 and All India Rainfall at 5y scales: + + .. jupyter-execute:: + + import pyleoclim as pyleo + ts_air = pyleo.utils.load_dataset('AIR') + ts_nino = pyleo.utils.load_dataset('NINO3') + coh = ts_air.wavelet_coherence(ts_nino) + coh.phase_stats(scales=5) + + One may also obtain phase angle statistics over an interval, like the 2-8y ENSO band: + + .. jupyter-execute:: + + phase = coh.phase_stats(scales=[2,8]) + print("The mean angle is {:4.2f}°".format(phase.mean_angle/np.pi*180)) + print(phase) + + From this example, one diagnoses a strong anti-phased relationship in the ENSO band, + with high von Mises concentration (kappa ~ 3.35 >> kappa_hi) and low circular + dispersion (sigma ~ 0.6 << sigma_lo). This would be strong evidence of a consistent + anti-phasing between NINO3 and AIR at those scales. + + ''' + scales = np.array(scales) + + if scales.max() > self.scale.max(): + warnings.warn("Requested scale exceeds largest scale in object. Truncating to "+str(self.scale.max())) + + if scales.size == 1: + scale_idx = np.argmin(np.abs(self.scale - scales)) + res = waveutils.angle_sig(self.phase[:,scale_idx],nMC=number,level=level) + elif scales.size == 2: + idx_lo = np.argmin(np.abs(self.scale - scales.min())) + idx_hi = np.argmin(np.abs(self.scale - scales.max())) + if (idx_hi >= idx_lo): + raise ValueError("Insufficiently spaced scales. Please pick a single one, or a wider interval") + else: # average phase over those scales + nt, ns = self.phase.shape + phase = np.empty((nt)) + for i in range(nt): + phase[i], _, _ = waveutils.angle_stats(self.phase[i,idx_hi:idx_lo]) + res = waveutils.angle_sig(phase,nMC=number,level=level) # assess significance + + return res + +class GlobalCoherence: + '''Class to store the results of cross spectral analysis + + Attributes + ---------- + + global_coh: numpy array + coherence values + + scale: numpy array + scale values + + frequency: numpy array + frequency values + + coi: numpy array + cone of influence values + + coh: Coherence + Original coherence object + + + See Also + -------- + + pyleoclim.core.series.Series.global_coherence : method to compute the spectral coherence''' + + def __init__(self, global_coh, coh, signif_qs=None,signif_method=None,qs=None, label='Coherence'): + self.global_coh = global_coh + self.label = label + self.coh = coh + self.signif_qs = signif_qs + self.signif_method = signif_method + self.qs = qs + + def copy(self): + '''Copy object + ''' + return deepcopy(self) + + def signif_test(self,method='ar1sim',number=200,qs=[.95]): + '''Perform a significance test on the coherence values + + Parameters + ---------- + method: str; {'ar1sim','CN','phaseran'} + method to use for the surrogate test. Default is 'ar1sim'. + + number: int + number of surrogates to generate. Default is 200 + + qs: list + list of quantiles to compute. Default is [.95] + + Returns + ------- + global_coh: pyleoclim.core.globalcoherence.GlobalCoherence + Global coherence with significance field filled in + + Examples + -------- + + .. jupyter-execute:: + + soi = pyleo.utils.load_dataset('SOI') + nino3 = pyleo.utils.load_dataset('NINO3') + + gcoh = soi.global_coherence(nino3) + gcoh_sig = gcoh.signif_test(number=10) + gcoh_sig.plot() + ''' + + from ..core.surrogateseries import SurrogateSeries + + new = self.copy() + + ts1 = self.coh.timeseries1 + ts2 = self.coh.timeseries2 + + surr1 = SurrogateSeries(method=method,number=number) + surr2 = SurrogateSeries(method=method,number=number) + + surr1.from_series(ts1) + surr2.from_series(ts2) + + coh_array = np.empty((number,len(self.global_coh))) + + wavelet_kwargs = { + 'freq':self.coh.frequency, # pass the frequency axis directly + 'settings':self.coh.wave_args, + 'method':self.coh.wave_method, + } + + for i in range(number): + surr_series1 = surr1.series_list[i] + surr_series2 = surr2.series_list[i] + surr_coh = surr_series1.global_coherence(surr_series2,wavelet_kwargs=wavelet_kwargs) + coh_array[i,:] = surr_coh.global_coh + + quantiles = mquantiles(coh_array,qs,axis=0) + new.signif_qs = quantiles.data + new.signif_method = method + new.qs = qs + + return new + + def plot(self,figsize=(8,8),xlim=None,xlabel=None,label=None,psd_y_label='PSD',coh_y_label='Coherence',coh_line_color='grey',ax=None,coh_ylim=(.4,1),fill_alpha=.3,fill_color='grey',coh_plot_kwargs=None, + savefig_settings=None,spectral_kwargs=None,legend=True,legend_kwargs=None,spec1_plot_kwargs=None,spec2_plot_kwargs=None): + '''Plot the coherence as a function of scale or frequency, alongside the spectrum of the two timeseries (using the same method used for the coherence). + + Parameters + ---------- + figsize: tuple + size of the figure. Default is (8,8). Only used if ax is None + + xlim: tuple + x limits for the plot. Default is None + + label: str + label of the plot + + xlabel: str + x label of the plot + + psd_y_label: str + y label of the power spectral density plot (left hand side) + + coh_y_label: str + y label of the coherence plot (right hand side) + + coh_line_color: str + color of the coherence line + + coh_ylim: tuple + y limits for the coherence plot. Default is (.4,1) + + fill_alpha: float + alpha value for the fill_between plot. Default is .3 + + fill_color : str + color of the fill_between plot + + coh_plot_kwargs: dict + additional arguments to pass to the pyleoclim.utils.plotting.plot_xy + + savefig_settings: dict + settings to pass to the pyleoclim.utils.plotting.savefig function + + spectral_kwargs: dict + additional arguments to pass to the pyleo.Series.spectral method + + spec1_plot_kwargs: dict + additional arguments to pass to the pyleo.Series.spectral method + + spec2_plot_kwargs: dict + additional arguments to pass to the pyleo.Series.spectral method + + legend: bool + whether to include a legend or not + + legend_kwargs: dict + additional arguments to pass to ax.legend + + ax: matplotlib axis + axis to plot on + + Returns + ------- + ax: matplotlib axis + axis with the plot + + Examples + -------- + + .. jupyter-execute:: + + soi = pyleo.utils.load_dataset('SOI') + nino3 = pyleo.utils.load_dataset('NINO3') + + gcoh = soi.global_coherence(nino3) + gcoh.plot()''' + + coh_plot_kwargs = {} if coh_plot_kwargs is None else coh_plot_kwargs.copy() + savefig_settings = {} if savefig_settings is None else savefig_settings.copy() + spectral_kwargs = {} if spectral_kwargs is None else spectral_kwargs.copy() + legend_kwargs = {} if legend_kwargs is None else legend_kwargs.copy() + spec1_plot_kwargs = {} if spec1_plot_kwargs is None else spec1_plot_kwargs.copy() + spec2_plot_kwargs = {} if spec2_plot_kwargs is None else spec2_plot_kwargs.copy() + + if ax is None: + fig,ax = plt.subplots(figsize=figsize) + else: + pass + + coh_dict = self.coh.__dict__ + + if 'method' not in spectral_kwargs: + spectral_kwargs.update({'method': coh_dict['wave_method']}) + if 'freq' not in spectral_kwargs: + spectral_kwargs.update({'freq': coh_dict['freq_method']}) + if 'freq_kwargs' not in spectral_kwargs: + spectral_kwargs.update({'freq_kwargs': coh_dict['freq_kwargs']}) + if spectral_kwargs['method'] == coh_dict['wave_method']: + for key,value in coh_dict['wave_args'].items(): + if key not in spectral_kwargs: + spectral_kwargs.update({key: value}) + + ts1 = coh_dict['timeseries1'] + ts2 = coh_dict['timeseries2'] + + spec1 = ts1.spectral(label=ts1.label, **spectral_kwargs) + spec2 = ts2.spectral(label=ts2.label, **spectral_kwargs) + + spec1.plot(ax=ax,**spec1_plot_kwargs) + spec2.plot(ax=ax,**spec2_plot_kwargs) + + if xlim is not None: + ax.set_xlim(xlim) + if xlabel is not None: + ax.set_xlabel(xlabel) + if psd_y_label is not None: + ax.set_ylabel(psd_y_label) + + ax2 = ax.twinx() + + if coh_line_color is not None: + coh_plot_kwargs.update({'color':coh_line_color}) + if coh_y_label is not None: + ax2.set_ylabel(coh_y_label) + if coh_ylim is not None: + ax2.set_ylim(coh_ylim) + if label is None: + label = self.label + coh_plot_kwargs.update({'label': label}) + + scale = coh_dict['scale'] + + ax2.plot(scale,self.global_coh,**coh_plot_kwargs) + ax2.fill_between(scale, 0, self.global_coh, color=fill_color, alpha=fill_alpha) + ax2.grid(False) + + # plot significance levels if present + if self.signif_qs is not None: + signif_method_label = { + 'ar1sim': 'AR(1) simulations (MoM)', + 'phaseran': 'Phase Randomization', + 'CN': 'Colored Noise' + } + + for i, q in enumerate(self.signif_qs): + ax.plot( + scale, q, + label=f'{signif_method_label[self.signif_method]}, {self.qs[i]} threshold', + color='red', + linestyle='dashed', + linewidth=.8, + ) + + #formatting + if legend: + if len(legend_kwargs) == 0: + ax.legend().set_visible(False) + lines, labels = ax.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + ax2.legend(lines + lines2, labels + labels2) + else: + lines, labels = ax.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + if 'handles' not in legend_kwargs: + legend_kwargs.update({'handles': lines+lines2}) + if 'labels' not in legend_kwargs: + legend_kwargs.update({'labels': labels+labels2}) + ax.legend(**legend_kwargs) + ax2.legend().set_visible(False) + else: + ax.legend().set_visible(False) + ax2.legend().set_visible(False) + + if 'fig' in locals(): + if 'path' in savefig_settings: + plotting.savefig(fig, settings=savefig_settings) + return fig, ax + else: + return ax diff --git a/pyleoclim/core/series.py b/pyleoclim/core/series.py index 4d5f9951..3d1e67ca 100644 --- a/pyleoclim/core/series.py +++ b/pyleoclim/core/series.py @@ -17,15 +17,15 @@ from ..utils import causality as causalutils from ..utils import decomposition from ..utils import filter as filterutils +from ..utils import lipdutils from ..core.psds import PSD from ..core.ssares import SsaRes from ..core.multipleseries import MultipleSeries from ..core.scalograms import Scalogram -from ..core.coherence import Coherence +from ..core.coherences import Coherence, GlobalCoherence from ..core.corr import Corr from ..core.resolutions import Resolution -from .globalcoherence import GlobalCoherence import seaborn as sns import matplotlib.pyplot as plt diff --git a/pyleoclim/tests/test_core_Coherence.py b/pyleoclim/tests/test_core_Coherences.py similarity index 72% rename from pyleoclim/tests/test_core_Coherence.py rename to pyleoclim/tests/test_core_Coherences.py index 07f50419..1639b3db 100644 --- a/pyleoclim/tests/test_core_Coherence.py +++ b/pyleoclim/tests/test_core_Coherences.py @@ -85,4 +85,40 @@ def test_phasestats_t0(self, gen_ts): ts1 = gen_ts ts2 = gen_ts coh = ts2.wavelet_coherence(ts1) - _ = coh.phase_stats(scales=[2,8]) \ No newline at end of file + _ = coh.phase_stats(scales=[2,8]) + +class TestUiGlobalCoherencePlot: + ''' Tests for GlobalCoherence.plot() + ''' + + def test_plot_t0(self, gen_ts): + ''' Test GlobalCoherence.plot with various parameters + ''' + ts1 = gen_ts + ts2 = gen_ts + coh = ts1.global_coherence(ts2) + fig,ax = coh.plot() + pyleo.closefig(fig) + + def test_plot_t1(self, gen_ts): + ''' Test GlobalCoherence.plot with signif tests + ''' + ts1 = gen_ts + ts2 = gen_ts + coh = ts1.global_coherence(ts2).signif_test(number=1) + fig,ax = coh.plot() + pyleo.closefig(fig) + +class TestUiGlobalCoherenceSignifTest: + ''' Tests for GlobalCoherence.signif_test() + ''' + + @pytest.mark.parametrize('method',['ar1sim','phaseran','CN']) + @pytest.mark.parametrize('number',[1,10]) + @pytest.mark.parametrize('qs',[[.95],[.05,.95]]) + def test_signiftest_t0(self,method,number, qs,gen_ts): + ''' Test GlobalCoherence.signif_test + ''' + ts1 = gen_ts + ts2 = gen_ts + _ = ts1.global_coherence(ts2).signif_test(method=method,number=number,qs=qs) \ No newline at end of file diff --git a/pyleoclim/tests/test_core_GeoSeries.py b/pyleoclim/tests/test_core_GeoSeries.py index 281c61f1..01c5f412 100644 --- a/pyleoclim/tests/test_core_GeoSeries.py +++ b/pyleoclim/tests/test_core_GeoSeries.py @@ -45,6 +45,7 @@ def multiple_pinkgeoseries(nrecs = 20, seed = 108, geobox=[-85.0,85.0,-180,180]) return pyleo.MultipleGeoSeries(ts_list, label='Multiple Pink GeoSeries') +@pytest.mark.xfail # will fail until pandas is fixed class TestUIGeoSeriesResample(): ''' test GeoSeries.Resample() ''' diff --git a/pyleoclim/tests/test_core_GlobalCoherence.py b/pyleoclim/tests/test_core_GlobalCoherence.py deleted file mode 100644 index ae4e366d..00000000 --- a/pyleoclim/tests/test_core_GlobalCoherence.py +++ /dev/null @@ -1,53 +0,0 @@ -''' Tests for pyleoclim.core.globalcoherence.GlobalCoherence - -Naming rules: -1. class: Test{filename}{Class}{method} with appropriate camel case -2. function: test_{method}_t{test_id} - -Notes on how to test: -0. Make sure [pytest](https://docs.pytest.org) has been installed: `pip install pytest` -1. execute `pytest {directory_path}` in terminal to perform all tests in all testing files inside the specified directory -2. execute `pytest {file_path}` in terminal to perform all tests in the specified file -3. execute `pytest {file_path}::{TestClass}::{test_method}` in terminal to perform a specific test class/method inside the specified file -4. after `pip install pytest-xdist`, one may execute "pytest -n 4" to test in parallel with number of workers specified by `-n` -5. for more details, see https://docs.pytest.org/en/stable/usage.html -''' - -import pytest -import pyleoclim as pyleo - -class TestUiGlobalCoherencePlot: - ''' Tests for GlobalCoherence.plot() - ''' - - def test_plot_t0(self, gen_ts): - ''' Test GlobalCoherence.plot with various parameters - ''' - ts1 = gen_ts - ts2 = gen_ts - coh = ts1.global_coherence(ts2) - fig,ax = coh.plot() - pyleo.closefig(fig) - - def test_plot_t1(self, gen_ts): - ''' Test GlobalCoherence.plot with signif tests - ''' - ts1 = gen_ts - ts2 = gen_ts - coh = ts1.global_coherence(ts2).signif_test(number=1) - fig,ax = coh.plot() - pyleo.closefig(fig) - -class TestUiGlobalCoherenceSignifTest: - ''' Tests for GlobalCoherence.signif_test() - ''' - - @pytest.mark.parametrize('method',['ar1sim','phaseran','CN']) - @pytest.mark.parametrize('number',[1,10]) - @pytest.mark.parametrize('qs',[[.95],[.05,.95]]) - def test_signiftest_t0(self,method,number, qs,gen_ts): - ''' Test GlobalCoherence.signif_test - ''' - ts1 = gen_ts - ts2 = gen_ts - _ = ts1.global_coherence(ts2).signif_test(method=method,number=number,qs=qs) \ No newline at end of file