Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Coherence time axis #501

Merged
merged 5 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 23 additions & 37 deletions pyleoclim/core/coherence.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
"""
from ..utils import plotting
from ..utils import wavelet as waveutils
from ..utils import lipdutils

from ..core.scalograms import Scalogram, MultipleScalogram

import matplotlib.pyplot as plt
Expand All @@ -22,28 +20,6 @@
from scipy.stats.mstats import mquantiles
import warnings

def infer_period_unit_from_time_unit(time_unit):
''' infer a period unit based on the given time unit

'''
if time_unit is None:
period_unit = None
else:
unit_group = lipdutils.timeUnitsCheck(time_unit)
if unit_group != 'unknown':
if unit_group == 'kage_units':
period_unit = 'kyrs'
else:
period_unit = 'yrs'
else:
if time_unit[-1] == 's':
period_unit = time_unit
else:
period_unit = f'{time_unit}s'

return period_unit


class Coherence:
'''Coherence object, meant to receive the WTC and XWT part of Series.wavelet_coherence()

Expand Down Expand Up @@ -85,9 +61,9 @@ def __init__(self, frequency, scale, time, wtc, xwt, phase, coi=None,
if scale_unit is not None:
self.scale_unit = scale_unit
elif timeseries1 is not None:
self.scale_unit = infer_period_unit_from_time_unit(timeseries1.time_unit)
self.scale_unit = plotting.infer_period_unit_from_time_unit(timeseries1.time_unit)
elif timeseries2 is not None:
self.scale_unit = infer_period_unit_from_time_unit(timeseries2.time_unit)
self.scale_unit = plotting.infer_period_unit_from_time_unit(timeseries2.time_unit)
else:
self.scale_unit = None

Expand Down Expand Up @@ -443,7 +419,7 @@ def plot(self, var='wtc', xlabel=None, ylabel=None, title='auto', figsize=[10, 8
return ax


def dashboard(self, title=None, figsize=[9,12], phase_style = {},
def dashboard(self, title=None, figsize=[9,12], overlap = True, phase_style = {},
line_colors = ['tab:blue','tab:orange'], savefig_settings={},
ts_plot_kwargs = None, wavelet_plot_kwargs= None):
''' Cross-wavelet dashboard, including the two series, WTC and XWT.
Expand All @@ -461,6 +437,19 @@ def dashboard(self, title=None, figsize=[9,12], phase_style = {},

Figure size. The default is [9, 12], as this is an information-rich figure.

overlap : boolean, optional
whether to restrict the plot to the period of overlap between the series. Defaults to True

phase_style : dict, optional

Arguments for the phase arrows. The default is {}. It includes:
- 'pt': the default threshold above which phase arrows will be plotted
- 'skip_x': the number of points to skip between phase arrows along the x-axis
- 'skip_y': the number of points to skip between phase arrows along the y-axis
- 'scale': number of data units per arrow length unit (see matplotlib.pyplot.quiver)
- 'width': shaft width in arrow units (see matplotlib.pyplot.quiver)
- 'color': arrow color (see matplotlib.pyplot.quiver)

line_colors : list, optional

Colors for the 2 traces For nomenclature, see https://matplotlib.org/stable/gallery/color/named_colors.html
Expand All @@ -473,16 +462,6 @@ def dashboard(self, title=None, figsize=[9,12], phase_style = {},
with or without a suffix; if the suffix is not given in "path", it will follow "format"
- "format" can be one of {"pdf", "eps", "png", "ps"}

phase_style : dict, optional

Arguments for the phase arrows. The default is {}. It includes:
- 'pt': the default threshold above which phase arrows will be plotted
- 'skip_x': the number of points to skip between phase arrows along the x-axis
- 'skip_y': the number of points to skip between phase arrows along the y-axis
- 'scale': number of data units per arrow length unit (see matplotlib.pyplot.quiver)
- 'width': shaft width in arrow units (see matplotlib.pyplot.quiver)
- 'color': arrow color (see matplotlib.pyplot.quiver)

ts_plot_kwargs : dict

arguments to be passed to the timeseries subplot, see pyleoclim.core.series.Series.plot for details
Expand Down Expand Up @@ -547,6 +526,9 @@ def dashboard(self, title=None, figsize=[9,12], phase_style = {},
gs = gridspec.GridSpec(8, 1)
gs.update(wspace=0, hspace=0.5) # add some breathing room
ax = {}

# assess period of overlap
xlims = np.min(self.time), np.max(self.time)

# 1) plot timeseries
#plt.rc('ytick', labelsize=8)
Expand All @@ -558,6 +540,8 @@ def dashboard(self, title=None, figsize=[9,12], phase_style = {},
ax['ts1'].spines['bottom'].set_visible(False)
ax['ts1'].grid(False)
ax['ts1'].set_xlabel('')
if overlap:
ax['ts1'].set_xlim(xlims)

ax['ts2'] = ax['ts1'].twinx()
self.timeseries2.plot(ax=ax['ts2'], color=line_colors[1], **ts_plot_kwargs, legend=False)
Expand All @@ -567,6 +551,8 @@ def dashboard(self, title=None, figsize=[9,12], phase_style = {},
ax['ts2'].spines['right'].set_visible(True)
ax['ts2'].spines['left'].set_visible(False)
ax['ts2'].grid(False)
if overlap:
ax['ts2'].set_xlim(xlims)

# 2) plot WTC
ax['wtc'] = plt.subplot(gs[2:5, 0], sharex=ax['ts1'])
Expand Down
7 changes: 4 additions & 3 deletions pyleoclim/core/multipleseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ class MultipleSeries:

'''
def __init__(self, series_list, time_unit=None, label=None, name=None):
from ..core.series import Series
from ..core.geoseries import GeoSeries
from ..core.lipdseries import LipdSeries

self.series_list = series_list
self.time_unit = time_unit
Expand All @@ -75,6 +72,10 @@ def __init__(self, series_list, time_unit=None, label=None, name=None):
warnings.warn("`name` is a deprecated property, which will be removed in future releases. Please use `label` instead.",
DeprecationWarning, stacklevel=2)
# check that all components are Series
from ..core.series import Series
from ..core.geoseries import GeoSeries
from ..core.lipdseries import LipdSeries

if not all([isinstance(ts, (Series, GeoSeries, LipdSeries)) for ts in self.series_list]):
raise ValueError('All components must be of the same type')

Expand Down
23 changes: 1 addition & 22 deletions pyleoclim/core/scalograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,6 @@
from scipy.stats.mstats import mquantiles

#from ..core import MultipleScalogram
def infer_period_unit_from_time_unit(time_unit):
''' infer a period unit based on the given time unit

'''
if time_unit is None:
period_unit = None
else:
unit_group = lipdutils.timeUnitsCheck(time_unit)
if unit_group != 'unknown':
if unit_group == 'kage_units':
period_unit = 'kyrs'
else:
period_unit = 'yrs'
else:
if time_unit[-1] == 's':
period_unit = time_unit
else:
period_unit = f'{time_unit}s'

return period_unit


class Scalogram:
'''
Expand Down Expand Up @@ -159,7 +138,7 @@ def __init__(self, frequency, scale, time, amplitude, coi=None, label=None, Neff
if scale_unit is not None:
self.scale_unit = scale_unit
elif timeseries is not None:
self.scale_unit = infer_period_unit_from_time_unit(timeseries.time_unit)
self.scale_unit = plotting.infer_period_unit_from_time_unit(timeseries.time_unit)
else:
self.scale_unit = None

Expand Down
35 changes: 29 additions & 6 deletions pyleoclim/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class Series:
Defaults to 'ascending'

verbose : bool
If True, will print warning messages if there is any
If True, will print warning messages if there are any

clean_ts : boolean flag
set to True to remove the NaNs and make time axis strictly prograde with duplicated timestamps reduced by averaging the values
Expand Down Expand Up @@ -151,8 +151,9 @@ def __init__(self, time, value, time_unit=None, time_name=None,
value = np.array(value)

if auto_time_params is None:
warnings.warn('auto_time_params is not specified. Currently default behavior sets this to True. In a future release, this will be changed to False.', UserWarning, stacklevel=2)
auto_time_params = True
if verbose:
warnings.warn('auto_time_params is not specified. Currently default behavior sets this to True, which might modify your supplied time metadata. Please set to False if you want a different behavior.', UserWarning, stacklevel=2)

if auto_time_params:
# assign time metadata if they are not provided or provided incorrectly
Expand Down Expand Up @@ -3200,7 +3201,21 @@ def wavelet_coherence(self, target_series, method='cwt', settings=None,
coh_wwz.plot()

As with wavelet analysis, both CWT and WWZ admit optional arguments through `settings`.
Significance is assessed similarly as with PSD or Scalogram objects:
For instance, one can adjust the resolution of the time axis on which coherence is evaluated:

.. jupyter-execute::

coh_wwz = ts_air.wavelet_coherence(ts_nino, method = 'wwz', settings = {'ntau':20})
coh_wwz.plot()

The frequency (scale) axis can also be customized, e.g. to focus on scales from 1 to 20y, with 24 scales:

.. jupyter-execute::

coh = ts_air.wavelet_coherence(ts_nino, freq_kwargs={'fmin':1/20,'fmax':1,'nf':24})
coh.plot()

Significance is assessed similarly to PSD or Scalogram objects:

.. jupyter-execute::

Expand All @@ -3218,6 +3233,8 @@ def wavelet_coherence(self, target_series, method='cwt', settings=None,
cwt_sig.dashboard()

Note: this design balances many considerations, and is not easily customizable.


'''
if not verbose:
warnings.simplefilter('ignore')
Expand All @@ -3234,8 +3251,9 @@ def wavelet_coherence(self, target_series, method='cwt', settings=None,
freq_kwargs = {} if freq_kwargs is None else freq_kwargs.copy()
freq = specutils.make_freq_vector(self.time, method=freq_method, **freq_kwargs)
args = {}
args['wwz'] = {'freq': freq}
args['wwz'] = {'freq': freq, 'verbose': verbose}
args['cwt'] = {'freq': freq}


# put on same time axes if necessary
if method == 'cwt' and not np.array_equal(self.time, target_series.time):
Expand All @@ -3261,12 +3279,17 @@ def wavelet_coherence(self, target_series, method='cwt', settings=None,
else:
ntau = np.min([np.size(ts1.time), np.size(ts2.time), 50])

tau = np.linspace(np.min(self.time), np.max(self.time), ntau)
if 'tau' in settings.keys():
tau = settings['tau']
else:
lb1, ub1 = np.min(ts1.time), np.max(ts1.time)
lb2, ub2 = np.min(ts2.time), np.max(ts2.time)
lb = np.max([lb1, lb2])
ub = np.min([ub1, ub2])

tau = np.linspace(lb, ub, ntau)
settings.update({'tau': tau})


args[method].update(settings)

# Apply WTC method
Expand Down
7 changes: 7 additions & 0 deletions pyleoclim/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def metadata():
)
}

@pytest.fixture
def gen_ts():
""" Generate realistic-ish Series for testing """
t,v = pyleo.utils.gen_ts(model='colored_noise',nt=50)
ts = pyleo.Series(t,v, verbose=False)
return ts

@pytest.fixture
def unevenly_spaced_series():
"""Pyleoclim series with unevenly spaced time axis"""
Expand Down
52 changes: 19 additions & 33 deletions pyleoclim/tests/test_core_Coherence.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,37 @@
import pytest
import pyleoclim as pyleo


def gen_ts(model,nt):
'wrapper for gen_ts in pyleoclim'

t,v = pyleo.utils.gen_ts(model=model,nt=nt)
ts=pyleo.Series(t,v)
return ts

# Tests below

class TestUiCoherencePlot:
''' Tests for Coherence.plot()
'''

def test_plot_t0(self):
def test_plot_t0(self, gen_ts):
''' Test Coherence.plot with default parameters
'''
nt = 200
ts1 = gen_ts(model='colored_noise', nt=nt)
ts2 = gen_ts(model='colored_noise', nt=nt)
ts1 = gen_ts
ts2 = gen_ts
coh = ts2.wavelet_coherence(ts1)
fig,ax = coh.plot()
pyleo.closefig(fig)

def test_plot_t1(self):
def test_plot_t1(self, gen_ts):
''' Test Coherence.plot WTC with significance
'''
nt = 200
ts1 = gen_ts(model='colored_noise', nt=nt)
ts2 = gen_ts(model='colored_noise', nt=nt)
ts1 = gen_ts
ts2 = gen_ts
coh = ts2.wavelet_coherence(ts1)

coh_signif = coh.signif_test(number=10,qs = [0.8, 0.9, .95])
fig,ax = coh_signif.plot(signif_thresh=0.99)
pyleo.closefig(fig)

def test_plot_t2(self):
def test_plot_t2(self, gen_ts):
''' Test Coherence.plot XWT with significance
'''
nt = 200
ts1 = gen_ts(model='colored_noise', nt=nt)
ts2 = gen_ts(model='colored_noise', nt=nt)
ts1 = gen_ts
ts2 = gen_ts
coh = ts2.wavelet_coherence(ts1)

coh_signif = coh.signif_test(number=10)
Expand All @@ -68,34 +57,31 @@ def test_plot_t2(self):
class TestUiCoherenceDashboard:
''' Tests for Coherence.dashboard()
'''
def test_dashboard_t0(self):
def test_dashboard_t0(self, gen_ts):
''' Test Coherence.dashboard() with default parameters
'''
nt = 200
ts1 = gen_ts(model='colored_noise', nt=nt)
ts2 = gen_ts(model='colored_noise', nt=nt)
ts1 = gen_ts
ts2 = gen_ts
coh = ts2.wavelet_coherence(ts1)
fig,ax = coh.dashboard()
pyleo.closefig(fig)

def test_dashboard_t1(self):
def test_dashboard_t1(self, gen_ts):
''' Test Coherence.dashboard() with optional parameter
'''
nt = 200
ts1 = gen_ts(model='colored_noise', nt=nt)
ts2 = gen_ts(model='colored_noise', nt=nt)
ts1 = gen_ts
ts2 = gen_ts
coh = ts2.wavelet_coherence(ts1)
fig, ax = coh.dashboard(wavelet_plot_kwargs={'contourf_style':{'cmap': 'cividis'}})
pyleo.closefig(fig)

class TestUiCoherencePhaseStats:
''' Tests for Coherence.phase_stats()
'''
def test_phasestats_t0(self):
def test_phasestats_t0(self, gen_ts):
''' Test Coherence.phase_stats() with default parameters
'''
nt = 200
ts1 = gen_ts(model='colored_noise', nt=nt)
ts2 = gen_ts(model='colored_noise', nt=nt)
ts1 = gen_ts
ts2 = gen_ts
coh = ts2.wavelet_coherence(ts1)
phase = coh.phase_stats(scales=[2,8])
_ = coh.phase_stats(scales=[2,8])
Loading