diff --git a/tedana/decay.py b/tedana/decay.py index 8304d30e6..63bd4b617 100644 --- a/tedana/decay.py +++ b/tedana/decay.py @@ -11,6 +11,35 @@ RefLGR = logging.getLogger('REFERENCES') +def _apply_t2s_floor(t2s, echo_times): + """ + Apply a floor to T2* values to prevent zero division errors during + optimal combination. + + Parameters + ---------- + t2s : (S,) array_like + T2* estimates. + echo_times : (E,) array_like + Echo times in milliseconds. + + Returns + ------- + t2s_corrected : (S,) array_like + T2* estimates with very small, positive values replaced with a floor value. + """ + t2s_corrected = t2s.copy() + echo_times = np.asarray(echo_times) + if echo_times.ndim == 1: + echo_times = echo_times[:, None] + + eps = np.finfo(dtype=t2s.dtype).eps # smallest value for datatype + temp_arr = np.exp(-echo_times / t2s) # (E x V) array + bad_voxel_idx = np.any(temp_arr == 0, axis=0) & (t2s != 0) + t2s_corrected[bad_voxel_idx] = np.min(-echo_times) / np.log(eps) + return t2s_corrected + + def monoexponential(tes, s0, t2star): """ Specifies a monoexponential model for use with scipy curve fitting @@ -258,9 +287,11 @@ def fit_decay(data, tes, mask, adaptive_mask, fittype): t2s_limited[np.isinf(t2s_limited)] = 500. # why 500? # let's get rid of negative values, but keep zeros where limited != full t2s_limited[(adaptive_mask_masked > 1) & (t2s_limited <= 0)] = 1. + t2s_limited = _apply_t2s_floor(t2s_limited, tes) s0_limited[np.isnan(s0_limited)] = 0. # why 0? t2s_full[np.isinf(t2s_full)] = 500. # why 500? t2s_full[t2s_full <= 0] = 1. # let's get rid of negative values! + t2s_full = _apply_t2s_floor(t2s_full, tes) s0_full[np.isnan(s0_full)] = 0. # why 0? t2s_limited = utils.unmask(t2s_limited, mask) diff --git a/tedana/tests/test_decay.py b/tedana/tests/test_decay.py index 2e8eb794a..953949e40 100644 --- a/tedana/tests/test_decay.py +++ b/tedana/tests/test_decay.py @@ -7,7 +7,7 @@ import numpy as np import pytest -from tedana import io, utils, decay as me +from tedana import io, utils, combine, decay as me from tedana.tests.utils import get_test_data_path @@ -58,6 +58,30 @@ def test_fit_decay_ts(testdata1): assert s0vG.ndim == 2 +def test__apply_t2s_floor(): + """ + _apply_t2s_floor applies a floor to T2* values to prevent a ZeroDivisionError during + optimal combination. + """ + n_voxels, n_echos, n_trs = 100, 5, 25 + echo_times = np.array([2, 23, 54, 75, 96]) + me_data = np.random.random((n_voxels, n_echos, n_trs)) + t2s = np.random.random((n_voxels)) * 1000 + t2s[t2s < 1] = 1 # Crop at 1 ms to be safe + t2s[0] = 0.001 + + # First establish a failure + with pytest.raises(ZeroDivisionError): + _ = combine._combine_t2s(me_data, echo_times[None, :], t2s[:, None]) + + # Now correct the T2* map and get a successful result. + t2s_corrected = me._apply_t2s_floor(t2s, echo_times) + assert t2s_corrected[0] != t2s[0] # First value should be corrected + assert np.array_equal(t2s_corrected[1:], t2s[1:]) # No other values should be corrected + combined = combine._combine_t2s(me_data, echo_times[None, :], t2s_corrected[:, None]) + assert np.all(combined != 0) + + # SMOKE TESTS def test_smoke_fit_decay():