-
Notifications
You must be signed in to change notification settings - Fork 96
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[REF] Modularize metric calculation (#447)
* Reorganize metrics. * Some more work on organizing metrics. * Add signal_minus_noise_z metric. * Variable name change. * Move comptable. * Move T2* cap into decay function. * Split up metric files. * Adjust cluster-extent thresholding to match across maps. * Partially address review. Lots of great refactoring by @rmarkello. * Make DICE broadcastable. * Clean up signal-noise metrics and fix compute_countnoise. * Simplify calculate_z_maps. * Fix dice (thanks @rmarkello) * Improve documentation. * Fix import. * Fix imports. * Get modularized metrics mostly working. * Fix bugs in metric calculations. All metrics should be calculated on *masked* data. Any metric maps should also be masked. * Revert changes to decision tree. * Finish reverting. * Fix viz. * Fix style issues * More??? * Fix bug in generate_metrics. * Add initial tests. * Improve docstrings, add shape checks, and reorder functions. * Fix assertions. * Fix style issue. * Add metric submodules to API docs. * Improve reporting for T_to_Z transform. * Fix bugs in new modularized dependence metrics.
- Loading branch information
1 parent
5e0b3ee
commit 913cd4c
Showing
11 changed files
with
1,066 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
""" | ||
Misc. utils for metric calculation. | ||
""" | ||
import numpy as np | ||
from scipy import stats | ||
|
||
|
||
def dependency_resolver(dict_, requested_metrics, base_inputs): | ||
""" | ||
Identify all necessary metrics based on a list of requested metrics and | ||
the metrics each one requires to be calculated, as defined in a dictionary. | ||
Parameters | ||
---------- | ||
dict_ : :obj:`dict` | ||
Dictionary containing lists, where each key is a metric name and its | ||
associated value is the list of metrics or inputs required to calculate | ||
it. | ||
requested_metrics : :obj:`list` | ||
Child metrics for which the function will determine parents. | ||
base_inputs : :obj:`list` | ||
A list of inputs to the metric collection function, to differentiate | ||
them from metrics to be calculated. | ||
Returns | ||
------- | ||
required_metrics :obj:`list` | ||
A comprehensive list of all metrics and inputs required to generate all | ||
of the requested inputs. | ||
""" | ||
not_found = [k for k in requested_metrics if k not in dict_.keys()] | ||
if not_found: | ||
raise ValueError('Unknown metric(s): {}'.format(', '.join(not_found))) | ||
|
||
required_metrics = requested_metrics | ||
while True: | ||
required_metrics_new = required_metrics[:] | ||
for k in required_metrics: | ||
if k in dict_.keys(): | ||
new_metrics = dict_[k] | ||
elif k not in base_inputs: | ||
print("Warning: {} not found".format(k)) | ||
required_metrics_new += new_metrics | ||
if set(required_metrics) == set(required_metrics_new): | ||
# There are no more parent metrics to calculate | ||
break | ||
else: | ||
required_metrics = required_metrics_new | ||
return required_metrics | ||
|
||
|
||
def determine_signs(weights, axis=0): | ||
""" | ||
Determine component-wise optimal signs using voxel-wise parameter estimates. | ||
Parameters | ||
---------- | ||
weights : (S x C) array_like | ||
Parameter estimates for optimally combined data against the mixing | ||
matrix. | ||
Returns | ||
------- | ||
signs : (C) array_like | ||
Array of 1 and -1 values corresponding to the appropriate flips for the | ||
mixing matrix's component time series. | ||
""" | ||
# compute skews to determine signs based on unnormalized weights, | ||
signs = stats.skew(weights, axis=axis) | ||
signs /= np.abs(signs) | ||
return signs | ||
|
||
|
||
def flip_components(*args, signs): | ||
""" | ||
Flip an arbitrary set of input arrays based on a set of signs. | ||
Parameters | ||
---------- | ||
*args : array_like | ||
Any number of arrays with one dimension the same length as signs. | ||
If multiple dimensions share the same size as signs, behavior of this | ||
function will be unpredictable. | ||
signs : array_like of :obj:`int` | ||
Array of +/- 1 by which to flip the values in each argument. | ||
Returns | ||
------- | ||
*args : array_like | ||
Input arrays after sign flipping. | ||
""" | ||
assert signs.ndim == 1, 'Argument "signs" must be one-dimensional.' | ||
for arg in args: | ||
assert len(signs) in arg.shape, \ | ||
('Size of argument "signs" must match size of one dimension in ' | ||
'each of the input arguments.') | ||
assert sum(x == len(signs) for x in arg.shape) == 1, \ | ||
('Only one dimension of each input argument can match the length ' | ||
'of argument "signs".') | ||
# correct mixing & weights signs based on spatial distribution tails | ||
return [arg * signs for arg in args] | ||
|
||
|
||
def sort_df(df, by='kappa', ascending=False): | ||
""" | ||
Sort DataFrame and get index. | ||
Parameters | ||
---------- | ||
df : :obj:`pandas.DataFrame` | ||
DataFrame to sort. | ||
by : :obj:`str`, optional | ||
Column by which to sort the DataFrame. Default is 'kappa'. | ||
ascending : :obj:`bool`, optional | ||
Whether to sort the DataFrame in ascending (True) or descending (False) | ||
order. Default is False. | ||
Returns | ||
------- | ||
df : :obj:`pandas.DataFrame` | ||
DataFrame after sorting, with index resetted. | ||
argsort : array_like | ||
Sorting index. | ||
""" | ||
# Order of kwargs is preserved at 3.6+ | ||
argsort = df[by].argsort() | ||
if not ascending: | ||
argsort = argsort[::-1] | ||
df = df.loc[argsort].reset_index(drop=True) | ||
return df, argsort | ||
|
||
|
||
def apply_sort(*args, sort_idx, axis=0): | ||
""" | ||
Apply a sorting index to an arbitrary set of arrays. | ||
""" | ||
for arg in args: | ||
assert arg.shape[axis] == len(sort_idx) | ||
return [np.take(arg, sort_idx, axis=axis) for arg in args] |
Oops, something went wrong.