diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml new file mode 100644 index 0000000..af929af --- /dev/null +++ b/.github/workflows/black.yml @@ -0,0 +1,11 @@ +name: Lint + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + - uses: psf/black@stable \ No newline at end of file diff --git a/.gitignore b/.gitignore index 238b20c..80f4e05 100644 --- a/.gitignore +++ b/.gitignore @@ -213,4 +213,7 @@ $RECYCLE.BIN/ # Windows shortcuts *.lnk -# End of https://www.gitignore.io/api/macos,linux,python,windows,jupyternotebook,visualstudiocode \ No newline at end of file +# End of https://www.gitignore.io/api/macos,linux,python,windows,jupyternotebook,visualstudiocode + +.idea +.vscode \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 7a73a41..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,2 +0,0 @@ -{ -} \ No newline at end of file diff --git a/jointly/__init__.py b/jointly/__init__.py index 959eb68..4a1cd8a 100644 --- a/jointly/__init__.py +++ b/jointly/__init__.py @@ -1,4 +1,4 @@ from .abstract_extractor import * -from .segment_selector import * from .shake_extractor import * +from .segment_selector import * from .synchronizer import * diff --git a/jointly/abstract_extractor.py b/jointly/abstract_extractor.py index 1e74e93..348dd24 100644 --- a/jointly/abstract_extractor.py +++ b/jointly/abstract_extractor.py @@ -2,7 +2,6 @@ class AbstractExtractor(metaclass=abc.ABCMeta): - def __init__(self): self.segments = {} @@ -12,7 +11,7 @@ def get_segments(self, signals): return dictionary with start and end timestamps for each signal. Format of dictionary: - >>> { + >>> { >>> 'column_name': { >>> 'first': { >>> 'start': timestamp, @@ -25,21 +24,21 @@ def get_segments(self, signals): >>> }, >>> ... >>> } - + """ def _init_segments(self, columns): self.segments = {} for column_name in columns: self.segments[column_name] = { - 'first': {}, - 'second': {}, + "first": {}, + "second": {}, } def _set_first_segment(self, column_name, start, end): - self.segments[column_name]['first']['start'] = start - self.segments[column_name]['first']['end'] = end + self.segments[column_name]["first"]["start"] = start + self.segments[column_name]["first"]["end"] = end def _set_second_segment(self, column_name, start, end): - self.segments[column_name]['second']['start'] = start - self.segments[column_name]['second']['end'] = end + self.segments[column_name]["second"]["start"] = start + self.segments[column_name]["second"]["end"] = end diff --git a/jointly/helpers.py b/jointly/helpers.py index 175e466..665989a 100644 --- a/jointly/helpers.py +++ b/jointly/helpers.py @@ -6,6 +6,7 @@ ### Signal processing + def normalize(x): """Normalizes signal to interval [-1, 1] with mean 0.""" xn = x - np.mean(x) @@ -17,19 +18,26 @@ def get_equidistant_signals(signals, frequency): """Returns copy of dataframe with signals sampled equidistantly at the specified frequency. 
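 
     Usage sketch (assuming ``signals`` is a DataFrame with a DatetimeIndex;
     the 100 Hz value is only an example):
 
     >>> df_equi = get_equidistant_signals(signals, frequency=100)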
""" - freq = '{}us'.format(int(1e6 / frequency)) - df = pd.DataFrame(index=pd.date_range(start=pd.to_datetime(signals.index.min(), unit='s'), - end=pd.to_datetime(signals.index.max(), unit='s'), - freq=freq)) - df = df.join(signals.copy(), how='outer') - df = df.interpolate(method='time', limit_area='inside').asfreq(freq) + freq = "{}N".format(int(1e9 / frequency)) + df = pd.DataFrame( + { + name: signals[name].dropna().resample(freq).nearest() + for name in signals.columns + }, + index=pd.date_range( + start=pd.to_datetime(signals.index.min(), unit="s"), + end=pd.to_datetime(signals.index.max(), unit="s"), + freq=freq, + ), + ) return df ### Plotting + def plot_signals(df, cols=None, title=None, tags=None): - cmap = matplotlib.cm.get_cmap('tab10') + cmap = matplotlib.cm.get_cmap("tab10") fig, ax = pyplot.subplots(figsize=(15, 6)) if cols is None: cols = df.columns @@ -37,12 +45,12 @@ def plot_signals(df, cols=None, title=None, tags=None): ax.set_title(title) if tags is not None: for tag in tags: - timestamp = pd.to_datetime(tag, unit='s') + timestamp = pd.to_datetime(tag, unit="s") if timestamp > df.index.min() and timestamp < df.index.max(): - ax.axvline(timestamp, color='grey') + ax.axvline(timestamp, color="grey") for index, col in enumerate(cols): ax.plot(df.index, df[col], color=cmap(index), label=col) - ax.set_xlabel('Time') + ax.set_xlabel("Time") fig.tight_layout() pyplot.legend() pyplot.show() @@ -58,10 +66,16 @@ def plot_segments(dataframe, segments, together=True, seperate=True): nrows = len(segment_names) fig, axes = pyplot.subplots(nrows, ncols, figsize=(15, 4 * nrows)) for index, segment_name in enumerate(segment_names): - axes[index].set_title('{} segment'.format(segment_name)) - signals_with_segment = list(filter(lambda x: segment_name in segments[x], signal_names)) - start = np.amin([segments[x][segment_name]['start'] for x in signals_with_segment]) - end = np.amax([segments[x][segment_name]['end'] for x in signals_with_segment]) + axes[index].set_title("{} segment".format(segment_name)) + signals_with_segment = list( + filter(lambda x: segment_name in segments[x], signal_names) + ) + start = np.amin( + [segments[x][segment_name]["start"] for x in signals_with_segment] + ) + end = np.amax( + [segments[x][segment_name]["end"] for x in signals_with_segment] + ) dataframe[start:end].plot(ax=axes[index]) fig.tight_layout() @@ -69,14 +83,17 @@ def plot_segments(dataframe, segments, together=True, seperate=True): # plot signals seperately ncols = len(segment_names) nrows = len(segments.keys()) - cmap = matplotlib.cm.get_cmap('tab10') + cmap = matplotlib.cm.get_cmap("tab10") fig, axes = pyplot.subplots(nrows, ncols, figsize=(15, 4 * nrows)) for index, signal_name in enumerate(segments.keys()): for index_seg, segment_name in enumerate(segment_names): if segment_name not in segments[signal_name]: continue - axes[index, index_seg].set_title('{} segment of {}'.format(segment_name, signal_name)) + axes[index, index_seg].set_title( + "{} segment of {}".format(segment_name, signal_name) + ) segment = segments[signal_name][segment_name] - dataframe[signal_name][segment['start']:segment['end']] \ - .plot(ax=axes[index, index_seg], color=cmap(index)) + dataframe[signal_name][segment["start"] : segment["end"]].plot( + ax=axes[index, index_seg], color=cmap(index) + ) fig.tight_layout() diff --git a/jointly/log.py b/jointly/log.py index 827a27e..4811cc7 100644 --- a/jointly/log.py +++ b/jointly/log.py @@ -2,5 +2,5 @@ FORMAT = "[%(lineno)3s - %(funcName)20s() ] %(message)s" 
 logging.basicConfig(format=FORMAT)
-logger = logging.getLogger('jointly')
+logger = logging.getLogger("jointly")
 logger.setLevel(logging.CRITICAL)
diff --git a/jointly/segment_selector.py b/jointly/segment_selector.py
index c985a14..44686f7 100644
--- a/jointly/segment_selector.py
+++ b/jointly/segment_selector.py
@@ -10,9 +10,10 @@ def __init__(self, signals, segments=None):
         if segments is None:
             self.segments = {
                 signal: {
-                    'first': {},
-                    'second': {},
-                } for signal in signals.columns
+                    "first": {},
+                    "second": {},
+                }
+                for signal in signals.columns
             }
         self._display_plots()

@@ -23,13 +24,13 @@ def _display_plots(self):

         for index, name in enumerate(self.signals.columns):
             axes[index].set_title(name)
-            axes[index].plot(self.signals[name].interpolate(method='time').values)
+            axes[index].plot(self.signals[name].interpolate(method="time").values)
         fig.tight_layout()

-        bp_id = fig.canvas.mpl_connect('button_press_event', self._on_click)
+        bp_id = fig.canvas.mpl_connect("button_press_event", self._on_click)

     def _display_segments(self):
-        print('Do')
+        print("Do")

     def _on_click(self, event):
         # if not event.dblclick:
@@ -40,8 +41,8 @@ def _on_click(self, event):
         title = event.inaxes.title.get_text()
         index = int(event.xdata)

-        for segment in ['first', 'second']:
-            for time in ['start', 'end']:
+        for segment in ["first", "second"]:
+            for time in ["start", "end"]:
                 if time not in self.segments[title][segment]:
                     self.segments[title][segment][time] = self.signals.index[index]
                     break
diff --git a/jointly/shake_extractor.py b/jointly/shake_extractor.py
index a96c057..f22ffc8 100644
--- a/jointly/shake_extractor.py
+++ b/jointly/shake_extractor.py
@@ -1,26 +1,39 @@
-import pprint
+from typing import List, Tuple

 import numpy as np
 import pandas as pd
-import scipy.interpolate
 import scipy.signal
+import scipy.interpolate
+import pprint

 from .abstract_extractor import AbstractExtractor
 from .log import logger
+from .synchronization_errors import (
+    BadThresholdException,
+    BadWindowException,
+    ShakeMissingException,
+)

 pp = pprint.PrettyPrinter()


 class ShakeExtractor(AbstractExtractor):
-    window = 600
-    """time window in seconds in which to look for peaks from start and end of signal"""
+    start_window_length = pd.Timedelta(seconds=600)
+    """time window in which to look for peaks at the start of the signal"""
+
+    end_window_length = pd.Timedelta(seconds=600)
+    """time window in which to look for peaks at the end of the signal"""
+
     threshold = 0.3
-    """threshold for peak detection"""
+    """min height for peak detection. In range [0, 1], as the data is normalized."""
+
     distance = 1500
     """distance in milliseconds in which the next peak must occur to be considered a sequence"""
+
     min_length = 6
     """minimum number of peaks per sequence"""
-    time_buffer = 1
+
+    time_buffer = pd.Timedelta(seconds=1)
     """time buffer padded before the first and after the last peak when deriving the segment timestamps"""

     def get_shake_weight(self, x):
@@ -32,17 +45,26 @@ def get_peak_sequences(self, signals, column, start_window, end_window):
         Sequences with a length less than min_length peaks are filtered.
         """
         sequences = []
-        logger.debug('Use peak threshold {}'.format(self.threshold))
+        if not (0 <= self.threshold <= 1):
+            raise BadThresholdException(
+                "Threshold must be a value in [0, 1], as the data is normalized"
+            )
+
+        logger.debug("Use peak threshold {}".format(self.threshold))
         start_part = signals[column].truncate(after=start_window)
-        peaks_start, _properties = scipy.signal.find_peaks(start_part, height=self.threshold)
+        peaks_start, _properties = scipy.signal.find_peaks(
+            start_part, height=self.threshold
+        )
         end_part = signals[column].truncate(before=end_window)
-        peaks_end, _properties = scipy.signal.find_peaks(end_part, height=self.threshold)
+        peaks_end, _properties = scipy.signal.find_peaks(
+            end_part, height=self.threshold
+        )
         peaks_end = peaks_end + signals.index.get_loc(end_part.index[0])
         peaks = [*peaks_start, *peaks_end]
-        logger.debug('Found {} peaks for {}'.format(len(peaks), column))
+        logger.debug("Found {} peaks for {}".format(len(peaks), column))

         for pos, index in enumerate(peaks):
             row = signals.iloc[[index]]
@@ -59,45 +81,98 @@ def get_peak_sequences(self, signals, column, start_window, end_window):
             else:
                 # start new sequence
                 sequences[len(sequences) - 1].append(row.index)
-        logger.debug('Merged peaks within {} ms to {} sequences for {}'.format(self.distance, len(sequences), column))
+        logger.debug(
+            "Merged peaks within {} ms to {} sequences for {}".format(
+                self.distance, len(sequences), column
+            )
+        )

         # filter sequences with less than min_length peaks
-        sequences_filtered = list(filter(lambda x: len(x) >= self.min_length, sequences))
-        logger.debug('{} sequences did satisfy minimum length of {} for {}'.format(len(sequences_filtered), self.min_length, column))
+        sequences_filtered = list(
+            filter(lambda x: len(x) >= self.min_length, sequences)
+        )
+        logger.debug(
+            "{} sequences satisfied the minimum length of {} for {}".format(
+                len(sequences_filtered), self.min_length, column
+            )
+        )
         return sequences_filtered

+    def _choose_sequence(self, signal, shake_list: List) -> Tuple:
+        if len(shake_list) > 0:
+            first = max(shake_list, key=self.get_shake_weight)
+            segment_start_time = first[0].index[0] - self.time_buffer
+            segment_start_index = signal.index.get_loc(
+                segment_start_time, method="nearest"
+            )
+            start = signal.index[segment_start_index]
+
+            segment_end_time = first[-1].index[0] + self.time_buffer
+            segment_end_index = signal.index.get_loc(segment_end_time, method="nearest")
+            end = signal.index[segment_end_index]
+
+            return start, end
+        else:
+            raise ShakeMissingException("No shake sequences detected")
+
     def get_segments(self, signals):
         """Returns dictionary with timestamps that mark start and end of each shake segment."""
         columns = list(signals.columns)
        self._init_segments(columns)
-        time_buffer = pd.Timedelta(seconds=self.time_buffer)  # will be added to start and subtracted from end of sequence
-        window = pd.Timedelta(seconds=self.window)  # time window from start and end in which to look for sequences

+        # time_buffer is padded before the first and after the last peak of the chosen sequence
         for column in columns:
-            start_window = signals[column].first_valid_index() + window
-            end_window = signals[column].last_valid_index() - window
+            last_timestamp = signals[column].last_valid_index()
+            first_timestamp = signals[column].first_valid_index()
+            duration = last_timestamp - first_timestamp
+            if duration < self.start_window_length or duration < self.end_window_length:
+                raise BadWindowException(
+                    f"Signal {column} is shorter than the configured start/end windows. "
+                    f"Shorten start_window_length and end_window_length so each covers "
+                    f"only the start or the end of the signal."
+                )
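+
+            # the start window spans [first_timestamp, first_timestamp +
+            # start_window_length]; the end window spans
+            # [last_timestamp - end_window_length, last_timestamp]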
+            start_window = first_timestamp + self.start_window_length
+            end_window = signals[column].last_valid_index() - self.end_window_length
             peaks = self.get_peak_sequences(signals, column, start_window, end_window)

             # map peak indices to their values
-            shakes = list(map(lambda sequence: (list(map(lambda index: signals[column][index], sequence))), peaks))
+            shakes = list(
+                map(
+                    lambda sequence: (
+                        list(map(lambda index: signals[column][index], sequence))
+                    ),
+                    peaks,
+                )
+            )

             # select sequences in start/end window
-            shakes_first = list(filter(lambda sequence: sequence[0].index[0] < start_window, shakes))
-            logger.debug('{} shakes in before {} for {}.'.format(len(shakes_first), start_window, column))
-            shakes_second = list(filter(lambda sequence: sequence[-1].index[0] > end_window, shakes))
-            logger.debug('{} shakes in after {} for {}.'.format(len(shakes_second), end_window, column))
+            shakes_first = list(
+                filter(lambda sequence: sequence[0].index[0] < start_window, shakes)
+            )
+            logger.debug(
+                "{} shakes before {} for {}.".format(
+                    len(shakes_first), start_window, column
+                )
+            )
+            shakes_second = list(
+                filter(lambda sequence: sequence[-1].index[0] > end_window, shakes)
+            )
+            logger.debug(
+                "{} shakes after {} for {}.".format(
+                    len(shakes_second), end_window, column
+                )
+            )

             # choose sequence with highest weight
-            if len(shakes_first) > 0:
-                first = max(shakes_first, key=self.get_shake_weight)
-                start = first[0].index[0] - time_buffer
-                end = first[-1].index[0] + time_buffer
-                self._set_first_segment(column, start, end)
-            if len(shakes_second) > 0:
-                second = max(shakes_second, key=self.get_shake_weight)
-                start = second[0].index[0] - time_buffer
-                end = second[-1].index[0] + time_buffer
-                self._set_second_segment(column, start, end)
-            logger.info('Shake segments for {}:\n{}'.format(column, pp.pformat(self.segments[column])))
+            start, end = self._choose_sequence(signals[column], shakes_first)
+            self._set_first_segment(column, start, end)
+
+            start, end = self._choose_sequence(signals[column], shakes_second)
+            self._set_second_segment(column, start, end)
+
+            logger.info(
+                "Shake segments for {}:\n{}".format(
+                    column, pp.pformat(self.segments[column])
+                )
+            )

         return self.segments
diff --git a/jointly/synchronization_errors.py b/jointly/synchronization_errors.py
new file mode 100644
index 0000000..5916cd8
--- /dev/null
+++ b/jointly/synchronization_errors.py
@@ -0,0 +1,14 @@
+class ShakeMissingException(Exception):
+    pass
+
+
+class BadThresholdException(Exception):
+    pass
+
+
+class StartEqualsEndError(Exception):
+    pass
+
+
+class BadWindowException(Exception):
+    pass
diff --git a/jointly/synchronizer.py b/jointly/synchronizer.py
index 9a46de2..a25eee3 100644
--- a/jointly/synchronizer.py
+++ b/jointly/synchronizer.py
@@ -1,14 +1,17 @@
-import logging
 import os
+import logging
+from pprint import pprint
+from typing import Dict

+import scipy.signal
 import numpy as np
 import pandas as pd
-import scipy.signal
 from matplotlib import pyplot

-from .abstract_extractor import AbstractExtractor
-from .helpers import normalize, get_equidistant_signals
 from .log import logger
+from .helpers import normalize, get_equidistant_signals
+from .abstract_extractor import AbstractExtractor
+from .synchronization_errors import ShakeMissingException, StartEqualsEndError


 class Synchronizer:
@@ -19,10 +22,12 @@ def extractor(self):

     @extractor.setter
     def extractor(self, value):
         if not issubclass(type(value), AbstractExtractor):
-            raise TypeError('Extractor needs to be a subclass of 
AbstractExtractor.') + raise TypeError("Extractor needs to be a subclass of AbstractExtractor.") self._extractor = value - def __init__(self, sources, ref_source_name, extractor, sampling_freq=None, tags=None): + def __init__( + self, sources, ref_source_name, extractor, sampling_freq=None, tags=None + ): self.sources = sources self.ref_source_name = ref_source_name self.extractor = extractor @@ -42,20 +47,25 @@ def truncate_data(self, buffer=300): self.ref_signals = self.ref_signals.truncate(before=before, after=after) for source in self.sources.values(): - source['data'] = source['data'].truncate(before=before, after=after) + source["data"] = source["data"].truncate(before=before, after=after) def _prepare_ref_signals(self): ref_signals = pd.DataFrame() for source_name, source in self.sources.items(): - signal = source['data'][source['ref_column']].dropna() - ref_signals = ref_signals.join(signal, how='outer') - ref_signals.rename(columns=(lambda x: source_name if x == source['ref_column'] else x), inplace=True) + signal = source["data"][source["ref_column"]].dropna() + ref_signals = ref_signals.join(signal, how="outer") + ref_signals.rename( + columns=(lambda x: source_name if x == source["ref_column"] else x), + inplace=True, + ) ref_signals = ref_signals.apply(normalize) return ref_signals def get_max_ref_frequency(self): if self.ref_signals is None: - raise ValueError('Unable to get maximum frequency: Reference signals undefined.') + raise ValueError( + "Unable to get maximum frequency: Reference signals undefined." + ) frequencies = self.ref_signals.aggregate(Synchronizer._infer_freq) return np.amax(frequencies) @@ -64,7 +74,7 @@ def _infer_freq(series): index = series.dropna().index timedeltas = index[1:] - index[:-1] median = np.median(timedeltas) - return np.timedelta64(1, 's') / median + return np.timedelta64(1, "s") / median @staticmethod def _stretch_signals(source, factor, start_time=None): @@ -72,7 +82,7 @@ def _stretch_signals(source, factor, start_time=None): df = source.copy() if start_time is None: start_time = df.index.min() - logger.debug('Use start time: {}'.format(start_time)) + logger.debug("Use start time: {}".format(start_time)) timedelta = df.index - start_time new_index = timedelta * factor + start_time df.set_index(new_index, inplace=True, verify_integrity=True) @@ -80,8 +90,8 @@ def _stretch_signals(source, factor, start_time=None): @staticmethod def _get_stretch_factor(segments, timeshifts): - old_length = segments['second']['start'] - segments['first']['start'] - new_length = old_length + timeshifts['second'] - timeshifts['first'] + old_length = segments["second"]["start"] - segments["first"]["start"] + new_length = old_length + timeshifts["second"] - timeshifts["first"] stretch_factor = new_length / old_length return stretch_factor @@ -92,54 +102,92 @@ def _get_timeshifts(dataframe, columns, segments): Expects equidistant sampled signals. 
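 
         In short, the estimate computed below for each segment is:
 
         >>> cross_corr = scipy.signal.correlate(ref_segment, sig_segment)
         >>> shift_in_samples = np.argmax(cross_corr) - len(sig_segment) + 1
 
         i.e. the signal segment is shifted by the lag that maximizes its
         cross-correlation with the reference segment.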
""" timeshifts = {} - segment_names = ['first', 'second'] + segment_names = ["first", "second"] ref_column = columns[0] sig_column = columns[1] + fig, axes = None, None if logger.isEnabledFor(logging.INFO): fig, axes = pyplot.subplots(1, 2, figsize=(15, 4)) + # Check that all segments are available + for col in columns: + for segment in segment_names: + for part in ["start", "end"]: + try: + segments[col][segment][part] + except KeyError: + print("Dumping all detected segments:") + pprint(segments) + raise ShakeMissingException( + f"No {segment} shake detected for {col}, missing the {part}" + ) + for index, segment in enumerate(segment_names): - logger.debug('Calculate timeshift of {} segment for {} to {}.'.format(segment, sig_column, ref_column)) + logger.debug( + "Calculate timeshift of {} segment for {} to {}.".format( + segment, sig_column, ref_column + ) + ) # get segments from both signals - ref_start = segments[ref_column][segment]['start'] - ref_end = segments[ref_column][segment]['end'] + ref_start = segments[ref_column][segment]["start"] + ref_end = segments[ref_column][segment]["end"] ref_segment = dataframe[ref_column][ref_start:ref_end] - sig_start = segments[sig_column][segment]['start'] - sig_end = segments[sig_column][segment]['end'] + sig_start = segments[sig_column][segment]["start"] + sig_end = segments[sig_column][segment]["end"] sig_segment = dataframe[sig_column][sig_start:sig_end] # calculate cross-correlation of segments cross_corr = scipy.signal.correlate(ref_segment, sig_segment) # get shift in samples - shift_in_samples = np.argmax(cross_corr) - len(sig_segment) - 1 + shift_in_samples = np.argmax(cross_corr) - len(sig_segment) + 1 # get timestamp at which sig_segment must start to sync signals - max_corr_ts = dataframe.index[dataframe.index.get_loc(ref_start, method='nearest') + shift_in_samples] - logger.debug('Highest correlation with start at {} with {}.'.format(max_corr_ts, np.max(cross_corr))) + max_corr_ts = dataframe.index[ + dataframe.index.get_loc(ref_start, method="nearest") + shift_in_samples + ] + logger.debug( + "Highest correlation with start at {} with {}.".format( + max_corr_ts, np.max(cross_corr) + ) + ) # calculate timeshift to move signal to maximize correlation timeshifts[segment] = max_corr_ts - sig_start - logger.debug('Timeshift is {}.'.format(str(timeshifts[segment]))) + logger.debug("Timeshift is {}.".format(str(timeshifts[segment]))) # plot shifted segments if logger.isEnabledFor(logging.INFO): - df = dataframe.copy() - df[sig_column] = df[sig_column].shift(1, freq=timeshifts[segment]) - axes[index].set_title('{} segment of {c[0]} and {c[1]}'.format(segment, c=columns)) - df[columns][ref_start:ref_end].plot(ax=axes[index]) + try: + df = dataframe.copy() + df[sig_column] = df[sig_column].shift(1, freq=timeshifts[segment]) + if axes is not None: + axes[index].set_title( + "{} segment of {c[0]} and {c[1]}".format(segment, c=columns) + ) + df[columns][ref_start:ref_end].plot(ax=axes[index]) + except MemoryError: + logger.warn( + f"Couldn't allocate enough memory to plot shifted segments, skipping" + ) if logger.isEnabledFor(logging.INFO): - fig.tight_layout() + try: + if fig is not None: + fig.tight_layout() + except MemoryError: + logger.warn( + f"Couldn't allocate enough memory to plot shifted segments, skipping" + ) return timeshifts def _calculate_sync_params(self): dataframe = self.ref_signals.copy() start_time = self.ref_signals.index.min() - self.sources[self.ref_source_name]['timeshift'] = None - 
+        self.sources[self.ref_source_name]["timeshift"] = None
+        self.sources[self.ref_source_name]["stretch_factor"] = 1

         # Interpolate and resample to equidistant signal
         df_equi = get_equidistant_signals(self.ref_signals, self.sampling_freq)
@@ -150,14 +198,39 @@ def _calculate_sync_params(self):
             if column == self.ref_source_name:
                 continue
             else:
-                timeshifts = Synchronizer._get_timeshifts(df_equi, [self.ref_source_name, column], segments)
-                logger.debug('Timedelta between shifts before stretching: {}'.format(timeshifts['first'] - timeshifts['second']))
-                self.sources[column]['stretch_factor'] = Synchronizer._get_stretch_factor(segments[column], timeshifts)
-                logger.info('Stretch factor for {}: {}'.format(column, self.sources[column]['stretch_factor']))
+                timeshifts = Synchronizer._get_timeshifts(
+                    df_equi, [self.ref_source_name, column], segments
+                )
+                logger.debug(
+                    "Timedelta between shifts before stretching: {}".format(
+                        timeshifts["first"] - timeshifts["second"]
+                    )
+                )
+                try:
+                    self.sources[column][
+                        "stretch_factor"
+                    ] = Synchronizer._get_stretch_factor(segments[column], timeshifts)
+                except ZeroDivisionError:
+                    raise StartEqualsEndError(
+                        "The first and second segments start at the same time, so the stretch factor is undefined. Check the start/end window parameters."
+                    )
+                logger.info(
+                    "Stretch factor for {}: {}".format(
+                        column, self.sources[column]["stretch_factor"]
+                    )
+                )

                 # stretch signal and exchange it in dataframe
-                signal_stretched = Synchronizer._stretch_signals(pd.DataFrame(dataframe[column]), self.sources[column]['stretch_factor'], start_time)
-                dataframe = dataframe.drop(column, axis='columns').join(signal_stretched, how='outer')
+                signal_stretched = Synchronizer._stretch_signals(
+                    pd.DataFrame(dataframe[column]).dropna(),
+                    self.sources[column]["stretch_factor"],
+                    start_time,
+                )
+                dataframe = (
+                    dataframe.drop(column, axis="columns")
+                    .join(signal_stretched, how="outer")
+                    .astype(pd.SparseDtype("float"))
+                )

         # Resample again with stretched signal
         df_equi = get_equidistant_signals(dataframe, self.sampling_freq)
@@ -168,47 +241,77 @@ def _calculate_sync_params(self):
             if column == self.ref_source_name:
                 continue
             else:
-                timeshifts = Synchronizer._get_timeshifts(df_equi, [self.ref_source_name, column], segments)
-                timedelta = timeshifts['first'] - timeshifts['second']
+                timeshifts = Synchronizer._get_timeshifts(
+                    df_equi, [self.ref_source_name, column], segments
+                )
+                timedelta = timeshifts["first"] - timeshifts["second"]
                 if timedelta > pd.Timedelta(0):
-                    logger.warning('Timedelta between shifts after stretching: {}'.format(timedelta))
-                logger.info('Timeshift for {}: {}'.format(column, timeshifts['first']))
-                self.sources[column]['timeshift'] = timeshifts['first']
+                    logger.warning(
+                        "Timedelta between shifts after stretching: {}".format(
+                            timedelta
+                        )
+                    )
+                logger.info("Timeshift for {}: {}".format(column, timeshifts["first"]))
+                self.sources[column]["timeshift"] = timeshifts["first"]

     def get_sync_params(self, recalculate=False):
-        selected_keys = ['timeshift', 'stretch_factor']
-        if recalculate or 'timeshift' not in self.sources[self.ref_source_name]:
+        selected_keys = ["timeshift", "stretch_factor"]
+        if recalculate or "timeshift" not in self.sources[self.ref_source_name]:
             self._calculate_sync_params()
         return {
             source_name: {
                 key: value for key, value in source.items() if key in selected_keys
-            } for source_name, source in self.sources.items()}
+            }
+            for source_name, source in self.sources.items()
+        }

-    def get_synced_data(self, recalculate=False):
+    def get_synced_data(self, recalculate=False) -> Dict[str, pd.DataFrame]:
         self.get_sync_params(recalculate)
         synced_data = {}
         start_time = self.ref_signals.index.min()
         for source_name, source in self.sources.items():
-            data = source['data'].copy()
-            if source['stretch_factor'] != 1:
-                data = Synchronizer._stretch_signals(data, source['stretch_factor'], start_time)
-            if source['timeshift'] is not None:
-                data = data.shift(1, freq=source['timeshift'] / 2)
+            data = source["data"].copy()
+            if source["stretch_factor"] != 1:
+                data = Synchronizer._stretch_signals(
+                    data, source["stretch_factor"], start_time
+                )
+            if source["timeshift"] is not None:
+                data = data.shift(1, freq=source["timeshift"])
             synced_data[source_name] = data
         return synced_data

+    def save_pickles(self, path) -> Dict[str, pd.DataFrame]:
+        """
+        Save a pickled, synced dataframe for each source. Does not save a total table. Sync parameters are saved as SYNC.csv.
+
+        Returns the synced data; the sync-parameter dataframe is included under the key "SYNC".
+        """
+        sync_params = pd.DataFrame(self.get_sync_params())
+        synced_data = self.get_synced_data()
+
+        sync_params.to_csv(os.path.join(path, "SYNC.csv"))
+
+        for source_name, synced_df in synced_data.items():
+            synced_df.to_pickle(os.path.join(path, f"{source_name.upper()}.PICKLE"))
+
+        return {**synced_data, "SYNC": sync_params}
+
     def save_data(self, path, tables=None, save_total_table=True):
-        if 'SYNC' in tables.keys():
-            raise ValueError('SYNC must not be one of the table names. It is reserved for the synchronization paramters.')
+        if "SYNC" in tables.keys():
+            raise ValueError(
+                "SYNC must not be one of the table names. It is reserved for the synchronization parameters."
+            )

-        if save_total_table and 'TOTAL' in tables.keys():
-            raise ValueError('TOTAL must not be one of the table names, if the table with all data should be saved.')
+        if save_total_table and "TOTAL" in tables.keys():
+            raise ValueError(
+                "TOTAL must not be one of the table names if the table with all data should be saved." 
+ ) sync_params = self.get_sync_params() synced_data = self.get_synced_data() # Save sync params - pd.DataFrame(sync_params).to_csv(os.path.join(path, 'SYNC.csv')) + pd.DataFrame(sync_params).to_csv(os.path.join(path, "SYNC.csv")) # Save custom tables logger.info(tables) @@ -216,27 +319,39 @@ def save_data(self, path, tables=None, save_total_table=True): for table_name, table_spec in tables.items(): table_df = pd.DataFrame() if self.tags is not None: - table_df = table_df.join(self.tags.data, how='outer') + table_df = table_df.join(self.tags.data, how="outer") for source_name, source_columns in table_spec.items(): # create dataframe for each source source_df = pd.DataFrame() for column in source_columns: if column in synced_data[source_name].columns: # join selected signals to device dataframe - source_df = source_df.join(synced_data[source_name][column], how='outer') + source_df = source_df.join( + synced_data[source_name][column], how="outer" + ) if not source_df.empty: # add device signals to general dataframe - source_df = source_df.rename(lambda x: '{prefix}_{column}'.format(prefix=source_name, column=x), axis='columns') - table_df = table_df.join(source_df, how='outer') - table_df.dropna(axis='index', how='all', inplace=True) - table_df.to_csv(os.path.join(path, '{filename}.csv'.format(filename=table_name))) + source_df = source_df.rename( + lambda x: "{prefix}_{column}".format( + prefix=source_name, column=x + ), + axis="columns", + ) + table_df = table_df.join(source_df, how="outer") + table_df.dropna(axis="index", how="all", inplace=True) + table_df.to_csv( + os.path.join(path, "{filename}.csv".format(filename=table_name)) + ) # Save table with total data if save_total_table: total_table = pd.DataFrame() if self.tags is not None: - total_table = total_table.join(self.tags.data, how='outer') + total_table = total_table.join(self.tags.data, how="outer") for source_name, data in synced_data.items(): - source_df = data.rename(lambda x: '{prefix}_{column}'.format(prefix=source_name, column=x), axis='columns') - total_table = total_table.join(source_df, how='outer') - total_table.to_csv(os.path.join(path, 'TOTAL.csv')) + source_df = data.rename( + lambda x: "{prefix}_{column}".format(prefix=source_name, column=x), + axis="columns", + ) + total_table = total_table.join(source_df, how="outer") + total_table.to_csv(os.path.join(path, "TOTAL.csv")) diff --git a/setup.py b/setup.py index 02ed3a3..6d853dc 100644 --- a/setup.py +++ b/setup.py @@ -12,20 +12,20 @@ from setuptools import find_packages, setup, Command # Package meta-data. -NAME = 'jointly' -DESCRIPTION = 'Synchronize multiple signals from different sources.' +NAME = "jointly" +DESCRIPTION = "Synchronize multiple signals from different sources." URL = None -EMAIL = 'felix.musmann@student.hpi.de' -AUTHOR = 'Felix Musmann' -REQUIRES_PYTHON = '>=3.6.0' -VERSION = '0.1.4' +EMAIL = "felix.musmann@student.hpi.de" +AUTHOR = "Felix Musmann" +REQUIRES_PYTHON = ">=3.7.0" +VERSION = "0.2.0" # What packages are required for this module to be executed? REQUIRED = [ - 'matplotlib', - 'numpy', - 'pandas', - 'scipy', + "matplotlib", + "numpy", + "pandas", + "scipy", ] # What packages are optional? @@ -43,8 +43,8 @@ # Import the README and use it as the long-description. # Note: this will only work if 'README.md' is present in your MANIFEST.in file! 
try: - with io.open(os.path.join(here, 'README.md'), encoding='utf-8') as f: - long_description = '\n' + f.read() + with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f: + long_description = "\n" + f.read() except FileNotFoundError: long_description = DESCRIPTION @@ -52,22 +52,22 @@ about = {} if not VERSION: project_slug = NAME.lower().replace("-", "_").replace(" ", "_") - with open(os.path.join(here, project_slug, '__version__.py')) as f: + with open(os.path.join(here, project_slug, "__version__.py")) as f: exec(f.read(), about) else: - about['__version__'] = VERSION + about["__version__"] = VERSION class UploadCommand(Command): """Support setup.py upload.""" - description = 'Build and publish the package.' + description = "Build and publish the package." user_options = [] @staticmethod def status(s): """Prints things in bold.""" - print('\033[1m{0}\033[0m'.format(s)) + print("\033[1m{0}\033[0m".format(s)) def initialize_options(self): pass @@ -77,20 +77,20 @@ def finalize_options(self): def run(self): try: - self.status('Removing previous builds…') - rmtree(os.path.join(here, 'dist')) + self.status("Removing previous builds…") + rmtree(os.path.join(here, "dist")) except OSError: pass - self.status('Building Source and Wheel (universal) distribution…') - os.system('{0} setup.py sdist bdist_wheel --universal'.format(sys.executable)) + self.status("Building Source and Wheel (universal) distribution…") + os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable)) - self.status('Uploading the package to PyPI via Twine…') - os.system('twine upload dist/*') + self.status("Uploading the package to PyPI via Twine…") + os.system("twine upload dist/*") - self.status('Pushing git tags…') - os.system('git tag v{0}'.format(about['__version__'])) - os.system('git push --tags') + self.status("Pushing git tags…") + os.system("git tag v{0}".format(about["__version__"])) + os.system("git push --tags") sys.exit() @@ -98,10 +98,10 @@ def run(self): # Where the magic happens: setup( name=NAME, - version=about['__version__'], + version=about["__version__"], description=DESCRIPTION, long_description=long_description, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", author=AUTHOR, author_email=EMAIL, python_requires=REQUIRES_PYTHON, @@ -109,26 +109,25 @@ def run(self): packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]), # If your package is a single module, use this instead of 'packages': # py_modules=['mypackage'], - # entry_points={ # 'console_scripts': ['mycli=mymodule:cli'], # }, install_requires=REQUIRED, extras_require=EXTRAS, include_package_data=True, - license='MIT', + license="MIT", classifiers=[ # Trove classifiers # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: Implementation :: CPython', - 'Programming Language :: Python :: Implementation :: PyPy' + "License :: OSI Approved :: MIT License", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", ], # $ setup.py publish support. cmdclass={ - 'upload': UploadCommand, + "upload": UploadCommand, }, )
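
Usage sketch for the API this diff touches (hedged: `internal_df`, `external_df`, and the "acc_mag" column are hypothetical placeholders; the source-dict layout with "data" and "ref_column" keys follows synchronizer.py above):

    import pandas as pd
    from jointly import ShakeExtractor, Synchronizer

    # each source maps to a DatetimeIndex'd dataframe plus the column to sync on
    sources = {
        "internal": {"data": internal_df, "ref_column": "acc_mag"},
        "external": {"data": external_df, "ref_column": "acc_mag"},
    }

    extractor = ShakeExtractor()
    # the former `window` attribute is now two pd.Timedelta windows
    extractor.start_window_length = pd.Timedelta(seconds=600)
    extractor.end_window_length = pd.Timedelta(seconds=600)

    sync = Synchronizer(sources, ref_source_name="internal", extractor=extractor)
    # writes SYNC.csv plus one <SOURCE>.PICKLE per source; returns the synced frames
    synced = sync.save_pickles("./out")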