[REF] Modularize metric calculation (#447)
* Reorganize metrics.

* Some more work on organizing metrics.

* Add signal_minus_noise_z metric.

* Variable name change.

* Move comptable.

* Move T2* cap into decay function.

* Split up metric files.

* Adjust cluster-extent thresholding to match across maps.

* Partially address review.

Lots of great refactoring by @rmarkello.

* Make DICE broadcastable.

* Clean up signal-noise metrics and fix compute_countnoise.

* Simplify calculate_z_maps.

* Fix dice (thanks @rmarkello)

* Improve documentation.

* Fix import.

* Fix imports.

* Get modularized metrics mostly working.

* Fix bugs in metric calculations.

All metrics should be calculated on *masked* data. Any metric maps
should also be masked.

* Revert changes to decision tree.

* Finish reverting.

* Fix viz.

* Fix style issues

* More???

* Fix bug in generate_metrics.

* Add initial tests.

* Improve docstrings, add shape checks, and reorder functions.

* Fix assertions.

* Fix style issue.

* Add metric submodules to API docs.

* Improve reporting for T_to_Z transform.

* Fix bugs in new modularized dependence metrics.
tsalo authored and handwerkerd committed Nov 22, 2019
1 parent 5e0b3ee commit 913cd4c
Showing 11 changed files with 1,066 additions and 17 deletions.
3 changes: 3 additions & 0 deletions docs/api.rst
@@ -98,6 +98,9 @@ API
 tedana.metrics.dependence_metrics
 tedana.metrics.kundu_metrics
 
+:template: module.rst
+tedana.metrics.collect
+tedana.metrics.dependence
 
 .. _api_selection_ref:

11 changes: 10 additions & 1 deletion tedana/decay.py
@@ -2,8 +2,10 @@
 Functions to estimate S0 and T2* from multi-echo data.
 """
 import logging
-import scipy
+
 import numpy as np
+import scipy
+from scipy import stats
 
 from tedana import utils
 
 LGR = logging.getLogger(__name__)
@@ -268,6 +270,13 @@ def fit_decay(data, tes, mask, adaptive_mask, fittype):
     t2s_full = utils.unmask(t2s_full, mask)
     s0_full = utils.unmask(s0_full, mask)
 
+    # set a hard cap for the T2* map
+    # anything that is 10x higher than the 99.5 %ile will be reset to 99.5 %ile
+    cap_t2s = stats.scoreatpercentile(t2s_limited.flatten(), 99.5,
+                                      interpolation_method='lower')
+    LGR.debug('Setting cap on T2* map at {:.5f}'.format(cap_t2s * 10))
+    t2s_limited[t2s_limited > cap_t2s * 10] = cap_t2s
+
     return t2s_limited, s0_limited, t2s_full, s0_full


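For intuition, the new cap behaves like this standalone sketch (illustrative values; `t2s` stands in for the T2* estimates inside `fit_decay`):

import numpy as np
from scipy import stats

t2s = np.array([0.02, 0.03, 0.05, 50.0])  # seconds; one implausibly large estimate
cap_t2s = stats.scoreatpercentile(t2s.flatten(), 99.5, interpolation_method='lower')
t2s[t2s > cap_t2s * 10] = cap_t2s  # 50.0 exceeds 10x the 99.5 %ile (0.05), so it is reset to 0.05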
7 changes: 5 additions & 2 deletions tedana/metrics/__init__.py
@@ -2,8 +2,11 @@
 # ex: set sts=4 ts=4 sw=4 et:
 
 from .kundu_fit import (
-    dependence_metrics, kundu_metrics, get_coeffs, computefeats2
+    kundu_metrics, dependence_metrics
 )
+from .collect import (
+    generate_metrics
+)
 
 __all__ = [
-    'dependence_metrics', 'kundu_metrics', 'get_coeffs', 'computefeats2']
+    'dependence_metrics', 'kundu_metrics', 'generate_metrics']
139 changes: 139 additions & 0 deletions tedana/metrics/_utils.py
@@ -0,0 +1,139 @@
"""
Misc. utils for metric calculation.
"""
import numpy as np
from scipy import stats


def dependency_resolver(dict_, requested_metrics, base_inputs):
    """
    Identify all necessary metrics based on a list of requested metrics and
    the metrics each one requires to be calculated, as defined in a dictionary.

    Parameters
    ----------
    dict_ : :obj:`dict`
        Dictionary containing lists, where each key is a metric name and its
        associated value is the list of metrics or inputs required to calculate
        it.
    requested_metrics : :obj:`list`
        Child metrics for which the function will determine parents.
    base_inputs : :obj:`list`
        A list of inputs to the metric collection function, to differentiate
        them from metrics to be calculated.

    Returns
    -------
    required_metrics : :obj:`list`
        A comprehensive list of all metrics and inputs required to generate all
        of the requested metrics.
    """
    not_found = [k for k in requested_metrics if k not in dict_.keys()]
    if not_found:
        raise ValueError('Unknown metric(s): {}'.format(', '.join(not_found)))

    required_metrics = requested_metrics
    while True:
        required_metrics_new = required_metrics[:]
        for k in required_metrics:
            if k in dict_.keys():
                new_metrics = dict_[k]
            elif k not in base_inputs:
                print("Warning: {} not found".format(k))
                new_metrics = []
            else:
                # Base inputs have no dependencies of their own
                new_metrics = []
            required_metrics_new += new_metrics
        if set(required_metrics) == set(required_metrics_new):
            # There are no more parent metrics to calculate
            break
        else:
            required_metrics = required_metrics_new
    return required_metrics
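
# Illustrative usage sketch (hypothetical metric names, not part of the
# committed file): 'kappa' requires two maps, which in turn require only
# base inputs, so the resolver returns the full chain.
_deps = {'kappa': ['map FT2', 'map Z'],
         'map FT2': ['mixing', 'data_cat'],
         'map Z': ['mixing', 'data_optcom']}
_base = ['mixing', 'data_cat', 'data_optcom']
# sorted(set(dependency_resolver(_deps, ['kappa'], _base))) returns
# ['data_cat', 'data_optcom', 'kappa', 'map FT2', 'map Z', 'mixing']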


def determine_signs(weights, axis=0):
    """
    Determine component-wise optimal signs using voxel-wise parameter estimates.

    Parameters
    ----------
    weights : (S x C) array_like
        Parameter estimates for optimally combined data against the mixing
        matrix.
    axis : :obj:`int`, optional
        Axis of `weights` along which to compute skew. Default is 0.

    Returns
    -------
    signs : (C) array_like
        Array of 1 and -1 values corresponding to the appropriate flips for the
        mixing matrix's component time series.
    """
    # compute skews to determine signs based on unnormalized weights
    signs = stats.skew(weights, axis=axis)
    signs /= np.abs(signs)
    return signs
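
# Illustrative usage sketch (not part of the committed file): a component
# with left-skewed weights gets a -1, so flipping it makes its heavy tail
# positive.
_rng = np.random.RandomState(0)
_weights = np.stack([_rng.gamma(2., 1., 1000),    # right-skewed -> +1
                     -_rng.gamma(2., 1., 1000)],  # left-skewed -> -1
                    axis=1)
# determine_signs(_weights, axis=0) returns array([ 1., -1.])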


def flip_components(*args, signs):
    """
    Flip an arbitrary set of input arrays based on a set of signs.

    Parameters
    ----------
    *args : array_like
        Any number of arrays with one dimension the same length as signs.
        If multiple dimensions share the same size as signs, behavior of this
        function will be unpredictable.
    signs : array_like of :obj:`int`
        Array of +/- 1 by which to flip the values in each argument.

    Returns
    -------
    *args : array_like
        Input arrays after sign flipping.
    """
    assert signs.ndim == 1, 'Argument "signs" must be one-dimensional.'
    for arg in args:
        assert len(signs) in arg.shape, \
            ('Size of argument "signs" must match size of one dimension in '
             'each of the input arguments.')
        assert sum(x == len(signs) for x in arg.shape) == 1, \
            ('Only one dimension of each input argument can match the length '
             'of argument "signs".')
    # correct mixing & weights signs based on spatial distribution tails
    return [arg * signs for arg in args]
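
# Illustrative usage sketch (not part of the committed file): flip the
# second of three components in the mixing matrix and the weights at once.
_signs = np.array([1, -1, 1])
_mixing = np.ones((100, 3))    # e.g. time x components
_weights = np.ones((1000, 3))  # e.g. voxels x components
_mixing, _weights = flip_components(_mixing, _weights, signs=_signs)
# _weights[0] is now array([ 1., -1.,  1.])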


def sort_df(df, by='kappa', ascending=False):
    """
    Sort DataFrame and get index.

    Parameters
    ----------
    df : :obj:`pandas.DataFrame`
        DataFrame to sort.
    by : :obj:`str`, optional
        Column by which to sort the DataFrame. Default is 'kappa'.
    ascending : :obj:`bool`, optional
        Whether to sort the DataFrame in ascending (True) or descending (False)
        order. Default is False.

    Returns
    -------
    df : :obj:`pandas.DataFrame`
        DataFrame after sorting, with index reset.
    argsort : array_like
        Sorting index.
    """
    # Order of kwargs is preserved at 3.6+
    argsort = df[by].argsort()
    if not ascending:
        argsort = argsort[::-1]
    df = df.loc[argsort].reset_index(drop=True)
    return df, argsort
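
# Illustrative usage sketch (not part of the committed file): sort a toy
# component table by kappa, descending, keeping the sorting index for
# reuse on matching arrays.
import pandas as pd

_comptable = pd.DataFrame({'kappa': [20., 80., 50.]})
_comptable, _sort_idx = sort_df(_comptable, by='kappa')
# _comptable['kappa'].tolist() is now [80.0, 50.0, 20.0]
# list(_sort_idx) is [1, 2, 0]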


def apply_sort(*args, sort_idx, axis=0):
    """
    Apply a sorting index to an arbitrary set of arrays.
    """
    for arg in args:
        assert arg.shape[axis] == len(sort_idx)
    return [np.take(arg, sort_idx, axis=axis) for arg in args]
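
# Illustrative usage sketch (not part of the committed file): reorder a
# per-component vector and a per-component table with the index returned
# by sort_df.
_comp_vals = np.array([10., 30., 20.])
_comp_table = np.array([[1., 2.], [3., 4.], [5., 6.]])
_comp_vals, _comp_table = apply_sort(_comp_vals, _comp_table,
                                     sort_idx=np.array([1, 2, 0]), axis=0)
# _comp_vals is now array([30., 20., 10.])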
