diff --git a/phys2bids/physio_obj.py b/phys2bids/physio_obj.py index 61bdd88eb..3466aa626 100644 --- a/phys2bids/physio_obj.py +++ b/phys2bids/physio_obj.py @@ -4,6 +4,7 @@ """I/O objects for phys2bids.""" import logging +from copy import deepcopy from itertools import groupby import numpy as np @@ -228,16 +229,16 @@ class BlueprintInput(): def __init__(self, timeseries, freq, ch_name, units, trigger_idx, num_timepoints_found=None, thr=None, time_offset=0): """Initialise BlueprintInput (see class docstring).""" - self.timeseries = is_valid(timeseries, list, list_type=np.ndarray) - self.freq = has_size(is_valid(freq, list, + self.timeseries = deepcopy(is_valid(timeseries, list, list_type=np.ndarray)) + self.freq = deepcopy(has_size(is_valid(freq, list, list_type=(int, float)), - self.ch_amount, 0.0) - self.ch_name = has_size(ch_name, self.ch_amount, 'unknown') - self.units = has_size(units, self.ch_amount, '[]') - self.trigger_idx = is_valid(trigger_idx, int) - self.num_timepoints_found = num_timepoints_found - self.thr = thr - self.time_offset = time_offset + self.ch_amount, 0.0)) + self.ch_name = deepcopy(has_size(ch_name, self.ch_amount, 'unknown')) + self.units = deepcopy(has_size(units, self.ch_amount, '[]')) + self.trigger_idx = deepcopy(is_valid(trigger_idx, int)) + self.num_timepoints_found = deepcopy(num_timepoints_found) + self.thr = deepcopy(thr) + self.time_offset = deepcopy(time_offset) @property def ch_amount(self): @@ -456,7 +457,7 @@ def check_trigger_amount(self, thr=None, num_timepoints_expected=0, tr=0): # Use the trigger channel to find the TRs, # comparing it to a given threshold. trigger = self.timeseries[self.trigger_idx] - time = self.timeseries[0].copy() + time = self.timeseries[0] LGR.info(f'The trigger is in channel {self.trigger_idx}') # Check that trigger and time channels have the same length. # If not, resample time to the length of the trigger @@ -586,12 +587,12 @@ class BlueprintOutput(): def __init__(self, timeseries, freq, ch_name, units, start_time, filename=''): """Initialise BlueprintOutput (see class docstring).""" - self.timeseries = is_valid(timeseries, np.ndarray) - self.freq = is_valid(freq, (int, float)) - self.ch_name = has_size(ch_name, self.ch_amount, 'unknown') - self.units = has_size(units, self.ch_amount, '[]') - self.start_time = start_time - self.filename = is_valid(filename, str) + self.timeseries = deepcopy(is_valid(timeseries, np.ndarray)) + self.freq = deepcopy(is_valid(freq, (int, float))) + self.ch_name = deepcopy(has_size(ch_name, self.ch_amount, 'unknown')) + self.units = deepcopy(has_size(units, self.ch_amount, '[]')) + self.start_time = deepcopy(start_time) + self.filename = deepcopy(is_valid(filename, str)) @property def ch_amount(self): diff --git a/phys2bids/tests/test_physio_obj.py b/phys2bids/tests/test_physio_obj.py index 4fc1ae6d9..b28be27ba 100644 --- a/phys2bids/tests/test_physio_obj.py +++ b/phys2bids/tests/test_physio_obj.py @@ -89,7 +89,7 @@ def test_BlueprintInput(): test_time = np.array([0, 1, 1, 2, 3, 5, 8, 13]) test_trigger = np.array([0, 1, 0, 0, 0, 0, 0, 0]) test_chocolate = np.array([1, 0, 0, 1, 0, 0, 1, 0]) - test_timeseries = [test_time.copy(), test_trigger, test_chocolate] + test_timeseries = [test_time, test_trigger, test_chocolate] test_freq = [42.0, 3.14, 20.0] test_chn_name = ['time', 'trigger', 'chocolate'] test_units = ['s', 's', 'sweetness'] @@ -101,11 +101,12 @@ def test_BlueprintInput(): # Tests rename_channels new_names = ['trigger', 'time', 'lindt'] - blueprint_in.rename_channels(new_names.copy()) + blueprint_in.rename_channels(new_names) assert blueprint_in.ch_name == ['time', 'trigger', 'lindt'] # Tests return_index test_index = blueprint_in.return_index(1) + assert test_index[0] is not test_trigger assert (test_index[0] == test_trigger).all() assert test_index[1] == len(test_timeseries) assert test_index[2] == test_freq[1] @@ -140,7 +141,7 @@ def test_cta_time_interp(): """Test BlueprintInput.check_trigger_amount with time resampling.""" test_time = np.array([0, 7]) test_trigger = np.array([0, 1, 0, 0, 0, 0, 0, 0]) - test_timeseries = [test_time.copy(), test_trigger] + test_timeseries = [test_time, test_trigger] test_freq = [42.0, 3.14] test_chn_name = ['time', 'trigger'] test_units = ['s', 's']