Output RMSE map and time series for decay model fit #1044

Merged
merged 44 commits into from
Apr 29, 2024
Changes from 40 commits
Commits
cc23fb7
Draft function to calculate decay model fit.
tsalo Feb 21, 2024
12cd117
Calculate root mean squared error instead.
tsalo Feb 21, 2024
2668f7d
Incorporate metrics.
tsalo Feb 21, 2024
a11f4ab
Merge remote-tracking branch 'upstream/main' into fit-model-fit
tsalo Feb 24, 2024
175c972
Output RMSE results.
tsalo Feb 24, 2024
123c916
Output results in tedana.
tsalo Feb 24, 2024
775d732
Hopefully fix things.
tsalo Feb 24, 2024
40e45c7
Update decay.py
tsalo Feb 24, 2024
1d315d1
Try improving performance.
tsalo Feb 25, 2024
6a2ce9a
Update decay.py
tsalo Feb 25, 2024
21e0584
Fix again.
tsalo Feb 25, 2024
8d15dd2
Use tqdm.
tsalo Feb 25, 2024
5e13af6
Update decay.py
tsalo Feb 25, 2024
dbe731c
Update decay.py
tsalo Feb 25, 2024
75549c1
Update decay.py
tsalo Feb 25, 2024
02dd091
Update expected outputs.
tsalo Feb 26, 2024
4f49d0e
Merge remote-tracking branch 'upstream/main' into fit-model-fit
tsalo Feb 29, 2024
d2a8d9e
Merge remote-tracking branch 'upstream/main' into fit-model-fit
tsalo Mar 26, 2024
e3704a0
Add figures.
tsalo Apr 12, 2024
107805e
Update outputs.
tsalo Apr 12, 2024
35bdebe
Merge remote-tracking branch 'upstream/main' into fit-model-fit
tsalo Apr 14, 2024
8d366db
Include global signal in confounds file.
tsalo Apr 16, 2024
941e2c7
Merge remote-tracking branch 'upstream/main' into fit-model-fit
tsalo Apr 16, 2024
c48f3ef
Merge remote-tracking branch 'upstream/main' into fit-model-fit
tsalo Apr 16, 2024
37bf10b
Update fiu_four_echo_outputs.txt
tsalo Apr 16, 2024
4a50094
Merge remote-tracking branch 'upstream/main' into fit-model-fit
tsalo Apr 18, 2024
4240461
Rename function.
tsalo Apr 19, 2024
9f7ea0e
Rename function.
tsalo Apr 19, 2024
ecfdabf
Update tedana.py
tsalo Apr 19, 2024
efd2ca1
Update tedana/decay.py
tsalo Apr 19, 2024
0bbfc96
Update decay.py
tsalo Apr 22, 2024
da96f2f
Update decay.py
tsalo Apr 22, 2024
ffb8f53
Whoops.
tsalo Apr 22, 2024
69ec73b
Apply suggestions from code review
tsalo Apr 26, 2024
3f22f90
Fix things maybe.
tsalo Apr 26, 2024
499e8bf
Fix things.
tsalo Apr 26, 2024
373b3eb
Update decay.py
tsalo Apr 26, 2024
4833cc0
Remove any files that are built through appending.
tsalo Apr 26, 2024
c5c1046
Update outputs.
tsalo Apr 26, 2024
ed98c7b
Add section on plots to docs.
tsalo Apr 26, 2024
f620ba7
Fix the description.
tsalo Apr 26, 2024
92758c1
Update docs/outputs.rst
tsalo Apr 26, 2024
4fb6108
Update docs/outputs.rst
tsalo Apr 26, 2024
730354a
Fix docstring.
tsalo Apr 27, 2024
Binary file added docs/_static/rmse_plots.png
26 changes: 24 additions & 2 deletions docs/outputs.rst
@@ -115,6 +115,8 @@ report.txt A s
"high kappa ts img": desc-optcomAccepted_bold.nii.gz High-kappa time series. This dataset does not
include thermal noise or low variance components.
Not the recommended dataset for analysis.
"confounds tsv": desc-confounds_timeseries.tsv Summary time series measures, including RMSE measures
of T2*/S0 model fit.
references.bib The BibTeX entries for references cited in
report.txt.

@@ -167,8 +169,8 @@ If ``gscontrol`` includes 'gsr'
Key: Filename Content
================================================================= =====================================================
"gs img": desc-globalSignal_map.nii.gz Spatial global signal
- "global signal time series tsv": desc-globalSignal_timeseries.tsv Time series of global signal from optimally combined
- data.
+ "confounds tsv": desc-confounds_timeseries.tsv Time series of global signal from optimally combined
+ data will be added to this file.
"has gs combined img": desc-optcomWithGlobalSignal_bold.nii.gz Optimally combined time series with global signal
retained.
"removed gs combined img": desc-optcomNoGlobalSignal_bold.nii.gz Optimally combined time series with global signal
@@ -563,6 +565,26 @@ It is important to note that the histogram is limited from 0 to the 98th percent
:height: 400px


*********************
Decay Model Fit Plots
*********************

Below the T2* and S0 summary plots are the decay model fit plots.
These plots show root mean squared error (RMSE) values for the
monoexponential decay model, based on the T2* and S0 maps.

The first plot is the mean RMSE brain plot, which shows the mean RMSE over time for each voxel in the brain.

The second plot is a time series of RMSE values summarized across the brain.
This plot includes the median RMSE time series,
along with an error band spanning the 25th to 75th percentiles,
and dashed lines indicating the 2nd and 98th percentiles.

.. image:: /_static/rmse_plots.png
:align: center
:height: 400px
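For reference, the quantity behind both plots can be written out explicitly. This is a sketch reconstructed from the `rmse_of_fit_decay_ts` code in this PR, assuming `monoexponential` has the usual form `s0 * exp(-te / t2star)`; the symbols ($n_v$ for the number of good echoes at voxel $v$, $S_{v,e,t}$ for the measured signal) are notation introduced here, not taken from the docs.

```latex
\mathrm{RMSE}(v, t) = \sqrt{\frac{1}{n_v} \sum_{e=1}^{n_v}
  \left( S_{v,e,t} - S_{0,v} \exp\!\left(-\frac{TE_e}{T_{2,v}^{*}}\right) \right)^{2}}
```

The brain plot then shows the mean of $\mathrm{RMSE}(v, t)$ over $t$ for each voxel, and the time series plot shows percentiles of $\mathrm{RMSE}(v, t)$ over $v$ for each volume. For volume-wise fits, $S_0$ and $T_2^{*}$ would also carry a $t$ index.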


**************************
Citable workflow summaries
**************************
1 change: 1 addition & 0 deletions pyproject.toml
@@ -33,6 +33,7 @@ dependencies = [
"scikit-learn>=0.21, <=1.4.2",
"scipy>=1.2.0, <=1.13.0",
"threadpoolctl",
"tqdm",
]
dynamic = ["version"]

109 changes: 107 additions & 2 deletions tedana/decay.py
@@ -1,10 +1,14 @@
"""Functions to estimate S0 and T2* from multi-echo data."""

import logging
from typing import List, Literal, Tuple

import numpy as np
import numpy.matlib
import pandas as pd
import scipy
from scipy import stats
from tqdm.auto import tqdm

from tedana import utils

@@ -112,7 +116,7 @@ def fit_monoexponential(data_cat, echo_times, adaptive_mask, report=True):
"estimate T2* and S0. In cases of model fit failure, T2*/S0 "
"estimates from the log-linear fit were retained instead."
)
- n_samp, n_echos, n_vols = data_cat.shape
+ n_samp, _, n_vols = data_cat.shape

# Currently unused
# fit_data = np.mean(data_cat, axis=2)
@@ -151,7 +155,7 @@ def fit_monoexponential(data_cat, echo_times, adaptive_mask, report=True):
# perform a monoexponential fit of echo times against MR signal
# using loglin estimates as initial starting points for fit
fail_count = 0
- for voxel in voxel_idx:
+ for voxel in tqdm(voxel_idx, desc=f"{echo_num}-echo monoexponential"):
try:
popt, cov = scipy.optimize.curve_fit(
monoexponential,
@@ -460,3 +464,104 @@ def fit_decay_ts(data, tes, mask, adaptive_mask, fittype):
report = False

return t2s_limited_ts, s0_limited_ts, t2s_full_ts, s0_full_ts


def rmse_of_fit_decay_ts(
    *,
    data: np.ndarray,
    tes: List[float],
    adaptive_mask: np.ndarray,
    t2s: np.ndarray,
    s0: np.ndarray,
    fitmode: Literal["all", "ts"],
) -> Tuple[np.ndarray, pd.DataFrame]:
    """Estimate model fit of voxel- and timepoint-wise monoexponential decay models to ``data``.

    Parameters
    ----------
    data : (S x E x T) :obj:`numpy.ndarray`
        Multi-echo data array, where `S` is samples, `E` is echoes, and `T` is time.
    tes : (E,) :obj:`list`
        Echo times.
    adaptive_mask : (S,) :obj:`numpy.ndarray`
        Array where each value indicates the number of echoes with good signal for that voxel.
        This mask may be thresholded; for example, with values less than 3 set to 0.
        For more information on thresholding, see :func:`~tedana.utils.make_adaptive_mask`.
    t2s : (S [x T]) :obj:`numpy.ndarray`
        Voxel-wise (and possibly volume-wise) T2* estimates from
        :func:`~tedana.decay.fit_decay_ts`.
    s0 : (S [x T]) :obj:`numpy.ndarray`
        Voxel-wise (and possibly volume-wise) S0 estimates from :func:`~tedana.decay.fit_decay_ts`.
    fitmode : {"all", "ts"}
        Whether the T2* and S0 estimates are volume-wise ("ts") or not ("all").

    Returns
    -------
    rmse_map : (S,) :obj:`numpy.ndarray`
        Mean root mean squared error of the model fit across all volumes at each voxel.
    rmse_df : :obj:`pandas.DataFrame`
        Each row is a timepoint and each column summarizes the root mean squared error
        of the model fit across voxels at that timepoint.
        Columns are mean, standard deviation, and percentiles across voxels. Column labels are
        "rmse_mean", "rmse_std", "rmse_min", "rmse_percentile02", "rmse_percentile25",
        "rmse_median", "rmse_percentile75", "rmse_percentile98", and "rmse_max".
    """
    n_samples, _, n_vols = data.shape
    tes = np.array(tes)

    rmse = np.full([n_samples, n_vols], np.nan, dtype=np.float32)
    # n_good_echoes iterates from 2 through the number of echoes
    # 0 and 1 are excluded because there aren't T2* and S0 estimates
    # for less than 2 good echoes. 2 echoes will have a bad estimate so consider
    # how/if we want to distinguish those
    for n_good_echoes in range(2, len(tes) + 1):
        # a boolean mask for voxels with a specific num of good echoes
        use_vox = adaptive_mask == n_good_echoes
        data_echo = data[use_vox, :n_good_echoes, :]
        if fitmode == "all":
            s0_echo = numpy.matlib.repmat(s0[use_vox].T, n_vols, 1).T
            t2s_echo = numpy.matlib.repmat(t2s[use_vox], n_vols, 1).T
        elif fitmode == "ts":
            s0_echo = s0[use_vox, :]
            t2s_echo = t2s[use_vox, :]
        else:
            raise ValueError(f"Unknown fitmode option {fitmode}")

        predicted_data = np.full([use_vox.sum(), n_good_echoes, n_vols], np.nan, dtype=np.float32)
        # Need to loop by echo since monoexponential can take either single vals for s0 and t2star
        # or a single TE value.
        # We could expand that func, but this is a functional solution
        for echo_num in range(n_good_echoes):
            predicted_data[:, echo_num, :] = monoexponential(
                tes=tes[echo_num],
                s0=s0_echo,
                t2star=t2s_echo,
            )
        rmse[use_vox, :] = np.sqrt(np.mean((data_echo - predicted_data) ** 2, axis=1))

    rmse_map = np.nanmean(rmse, axis=1)
    rmse_timeseries = np.nanmean(rmse, axis=0)
    rmse_sd_timeseries = np.nanstd(rmse, axis=0)
    rmse_percentiles_timeseries = np.nanpercentile(rmse, [0, 2, 25, 50, 75, 98, 100], axis=0)

    rmse_df = pd.DataFrame(
        columns=[
            "rmse_mean",
            "rmse_std",
            "rmse_min",
            "rmse_percentile02",
            "rmse_percentile25",
            "rmse_median",
            "rmse_percentile75",
            "rmse_percentile98",
            "rmse_max",
        ],
        data=np.column_stack(
            (
                rmse_timeseries,
                rmse_sd_timeseries,
                rmse_percentiles_timeseries.T,
            )
        ),
    )

    return rmse_map, rmse_df
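The per-echo loop and `repmat` calls above could arguably be replaced by NumPy broadcasting. The following is only a sketch of that idea for the voxel-wise (`fitmode="all"`) case: `rmse_by_broadcasting` is a hypothetical helper that is not part of tedana, and the local `monoexponential` is an assumed stand-in for the function in `tedana/decay.py`.

```python
import numpy as np


def monoexponential(tes, s0, t2star):
    """Assumed form of the decay model: S(TE) = S0 * exp(-TE / T2*)."""
    return s0 * np.exp(-tes / t2star)


def rmse_by_broadcasting(data, tes, adaptive_mask, t2s, s0):
    """Hypothetical helper: voxel-by-volume RMSE for voxel-wise T2*/S0 estimates.

    data is (S, E, T); tes is (E,); adaptive_mask, t2s, and s0 are (S,).
    """
    tes = np.asarray(tes, dtype=float)
    n_samples, n_echoes, n_vols = data.shape

    # Predicted signal for every voxel and echo, shape (S, E).
    predicted = monoexponential(tes[np.newaxis, :], s0[:, np.newaxis], t2s[:, np.newaxis])
    squared_error = (data - predicted[:, :, np.newaxis]) ** 2  # (S, E, T)

    # Only the first n good echoes of each voxel contribute to its error.
    good = np.arange(n_echoes)[np.newaxis, :] < adaptive_mask[:, np.newaxis]  # (S, E)
    summed = np.where(good[:, :, np.newaxis], squared_error, 0.0).sum(axis=1)  # (S, T)

    # Voxels with fewer than 2 good echoes have no T2*/S0 estimate, so they stay NaN.
    rmse = np.full((n_samples, n_vols), np.nan)
    valid = adaptive_mask >= 2
    rmse[valid] = np.sqrt(summed[valid] / adaptive_mask[valid, np.newaxis])
    return rmse


# Synthetic data: four voxels, three echoes, five volumes.
tes = [15.0, 39.0, 63.0]
s0 = np.array([1000.0, 800.0, 1200.0, 900.0])
t2s = np.array([30.0, 45.0, 25.0, 40.0])
adaptive_mask = np.array([3, 2, 3, 1])
clean = monoexponential(np.asarray(tes)[np.newaxis, :], s0[:, np.newaxis], t2s[:, np.newaxis])
data = np.repeat(clean[:, :, np.newaxis], 5, axis=2)
data[0] += 10.0  # constant misfit of 10 signal units on every echo of voxel 0

rmse = rmse_by_broadcasting(data, tes, adaptive_mask, t2s, s0)
print(rmse[:, 0])
```

In this toy case voxel 0 has an RMSE of 10 at every volume, voxels 1 and 2 fit perfectly, and voxel 3 (one good echo) is left as NaN. Whether this is actually faster than the loop in the PR would need to be measured on real data.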
2 changes: 1 addition & 1 deletion tedana/gscontrol.py
@@ -86,7 +86,7 @@ def gscontrol_raw(catd, optcom, n_echos, io_generator, dtrank=4):
glsig = stats.zscore(glsig, axis=None)

glsig_df = pd.DataFrame(data=glsig.T, columns=["global_signal"])
- io_generator.save_file(glsig_df, "global signal time series tsv")
+ io_generator.add_df_to_file(glsig_df, "confounds tsv")
glbase = np.hstack([legendre_arr, glsig.T])

# Project global signal out of optimally combined data
33 changes: 33 additions & 0 deletions tedana/io.py
@@ -155,6 +155,14 @@ def __init__(
LGR.info(f"Generating figures directory: {self.figures_dir}")
os.mkdir(self.figures_dir)

# Remove files that are appended to instead of overwritten.
if overwrite:
files_to_remove = ["confounds tsv"]
for file_ in files_to_remove:
filepath = self.get_name(file_)
if op.exists(filepath):
os.remove(filepath)
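The reason for this cleanup step is easier to see in a toy run. The sketch below uses hypothetical helpers (`append_columns`, `run_workflow`) rather than the real `OutputGenerator`, and assumes the confounds file is built by appending columns as in `add_df_to_file`. It shows a rerun accumulating a duplicate column unless the stale file is removed first.

```python
import os
import tempfile

import pandas as pd


def append_columns(data, path):
    """Hypothetical helper mimicking append-style writes to a TSV."""
    if os.path.isfile(path):
        data = pd.concat([pd.read_table(path), data], axis=1)
    data.to_csv(path, sep="\t", index=False)


def run_workflow(path, overwrite):
    """Hypothetical workflow run that appends one column to the confounds file."""
    if overwrite and os.path.exists(path):
        # Start from a clean slate so columns from a previous run are not kept.
        os.remove(path)
    append_columns(pd.DataFrame({"global_signal": [0.1, -0.2, 0.3]}), path)
    # Report how many columns the file holds after this run.
    return pd.read_table(path).shape[1]


with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "desc-confounds_timeseries.tsv")
    n_first = run_workflow(path, overwrite=True)
    n_stale = run_workflow(path, overwrite=False)  # appends a duplicate column
    n_clean = run_workflow(path, overwrite=True)  # stale file removed first

print(n_first, n_stale, n_clean)
```

Without the removal, the second run leaves two columns in the file; with it, the file goes back to one.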

def _determine_extension(self, description, name):
"""Infer the extension for a file based on its description.

@@ -346,6 +354,31 @@ def save_tsv(self, data, name):
deblanked = data.replace("", np.nan)
deblanked.to_csv(name, sep="\t", lineterminator="\n", na_rep="n/a", index=False)

    def add_df_to_file(self, data, description, **kwargs):
        """Add a DataFrame to a tsv file, which may or may not exist.

        Parameters
        ----------
        data : pandas.DataFrame
            Data to add to the file as new columns.
        description : str
            Description of the data, used to determine the appropriate filename from
            ``self.config``.

        Returns
        -------
        name : str
            The full file path of the saved file.
        """
        name = self.get_name(description, **kwargs)
        if op.isfile(name):
            # Append the new columns to whatever is already in the file.
            old_data = pd.read_table(name)
            data = pd.concat([old_data, data], axis=1, ignore_index=False)

        self.save_tsv(data, name)

        return name
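As a rough illustration of the append behavior, here is a standalone sketch using plain pandas. `add_columns_to_tsv` is a hypothetical stand-in for the method above (it takes a file path instead of a description key), and the column values are made up.

```python
import os
import tempfile

import pandas as pd


def add_columns_to_tsv(data, path):
    """Hypothetical stand-in: add columns to a TSV that may not exist yet."""
    if os.path.isfile(path):
        old_data = pd.read_table(path)
        # axis=1 appends new columns; rows are matched on the default integer index.
        data = pd.concat([old_data, data], axis=1)
    data.to_csv(path, sep="\t", na_rep="n/a", index=False)
    return path


with tempfile.TemporaryDirectory() as tmpdir:
    tsv_path = os.path.join(tmpdir, "desc-confounds_timeseries.tsv")
    # First call creates the file, second call adds a column to it.
    add_columns_to_tsv(pd.DataFrame({"global_signal": [0.1, -0.2, 0.3]}), tsv_path)
    add_columns_to_tsv(pd.DataFrame({"rmse_median": [1.5, 1.7, 1.6]}), tsv_path)
    combined = pd.read_table(tsv_path)

print(list(combined.columns))
```

Because rows are matched by position, this approach presumably relies on every appended DataFrame having one row per volume; a frame with a different number of rows would be padded with missing values.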

def save_self(self):
"""Save the registry to a json file.

7 changes: 7 additions & 0 deletions tedana/reporting/data/html/report_body_template.html
@@ -200,6 +200,13 @@ <h2>S0</h2>
<div class="carpet-plots-image">
<img id="s0Histogram" src="$s0Histogram" style="height: 500px" />
</div>
<h2>T2* and S0 model fit (RMSE). (Scaled between 2nd and 98th percentiles)</h2>
<div class="carpet-plots-image">
<img id="rmseBrainPlot" src="$rmseBrainPlot" style="height:500px;" />
</div>
<div class="carpet-plots-image">
<img id="rmseTimeseries" src="$rmseTimeseries" style="height:500px;" />
</div>
</div>
</div>
<div class="info">
5 changes: 5 additions & 0 deletions tedana/reporting/html_report.py
@@ -151,6 +151,8 @@ def _update_template_bokeh(bokeh_id, info_table, about, prefix, references, boke
t2star_histogram = f"./figures/{prefix}t2star_histogram.svg"
s0_brain = f"./figures/{prefix}s0_brain.svg"
s0_histogram = f"./figures/{prefix}s0_histogram.svg"
rmse_brain = f"./figures/{prefix}rmse_brain.svg"
rmse_timeseries = f"./figures/{prefix}rmse_timeseries.svg"

# Convert bibtex to html
references, bibliography = _bib2html(references)
@@ -162,6 +164,7 @@ def _update_template_bokeh(bokeh_id, info_table, about, prefix, references, boke
body_template_path = resource_path.joinpath(body_template_name)
with open(str(body_template_path)) as body_file:
body_tpl = Template(body_file.read())

body = body_tpl.substitute(
content=bokeh_id,
info=info_table,
@@ -173,6 +176,8 @@ def _update_template_bokeh(bokeh_id, info_table, about, prefix, references, boke
t2starHistogram=t2star_histogram,
s0BrainPlot=s0_brain,
s0Histogram=s0_histogram,
rmseBrainPlot=rmse_brain,
rmseTimeseries=rmse_timeseries,
references=references,
javascript=bokeh_js,
buttons=buttons,
88 changes: 88 additions & 0 deletions tedana/reporting/static_figures.py
@@ -618,6 +618,94 @@ def plot_t2star_and_s0(
)


def plot_rmse(
    *,
    io_generator: io.OutputGenerator,
    adaptive_mask: np.ndarray,
):
    """
    Create plots of the root mean squared error (RMSE) of the T2*/S0 decay model fit.

    Parameters
    ----------
    io_generator : :obj:`~tedana.io.OutputGenerator`
        The output generator for this workflow
    adaptive_mask : (S,) :obj:`numpy.ndarray`
        A mask where each value is the number of good echoes.
        Since the T2* and S0 estimations require a minimum of 2
        good echoes, the outputted plots will only include mask
        values of at least 2.
    """
    import pandas as pd

    rmse_img = io_generator.get_name("rmse img")
    confounds_file = io_generator.get_name("confounds tsv")
    # Mask that only includes values >=2 (i.e. at least 2 good echoes)
    mask_img = io.new_nii_like(io_generator.reference_img, (adaptive_mask >= 2).astype(np.int32))

    rmse_data = masking.apply_mask(rmse_img, mask_img)
    rmse_p02, rmse_p98 = np.percentile(rmse_data, [2, 98])

    # Get repetition time from reference image
    tr = io_generator.reference_img.header.get_zooms()[-1]

    # Load the confounds file
    confounds_df = pd.read_table(confounds_file)

    fig, ax = plt.subplots(figsize=(10, 6))
    rmse_arr = confounds_df["rmse_median"].values
    p25_arr = confounds_df["rmse_percentile25"].values
    p75_arr = confounds_df["rmse_percentile75"].values
    p02_arr = confounds_df["rmse_percentile02"].values
    p98_arr = confounds_df["rmse_percentile98"].values
    time_arr = np.arange(confounds_df.shape[0]) * tr
    ax.plot(time_arr, rmse_arr, color="black")
    ax.fill_between(
        time_arr,
        p25_arr,
        p75_arr,
        color="blue",
        alpha=0.2,
    )
    ax.plot(time_arr, p02_arr, color="black", linestyle="dashed")
    ax.plot(time_arr, p98_arr, color="black", linestyle="dashed")
    ax.set_ylabel("RMSE", fontsize=16)
    ax.set_xlabel(
        "Time (s)",
        fontsize=16,
    )
    ax.legend(["Median", "25th-75th percentiles", "2nd and 98th percentiles"])
    ax.set_title("Root mean squared error of T2* and S0 fit across voxels", fontsize=20)
    rmse_ts_plot = os.path.join(
        io_generator.out_dir,
        "figures",
        f"{io_generator.prefix}rmse_timeseries.svg",
    )
    # Show the full run, through the last timepoint
    ax.set_xlim(0, time_arr[-1])
    fig.savefig(rmse_ts_plot)
    plt.close(fig)

    # Plot RMSE
    rmse_brain_plot = os.path.join(
        io_generator.out_dir,
        "figures",
        f"{io_generator.prefix}rmse_brain.svg",
    )
    plotting.plot_stat_map(
        rmse_img,
        bg_img=None,
        display_mode="mosaic",
        cut_coords=4,
        symmetric_cbar=False,
        black_bg=True,
        cmap="Reds",
        vmin=rmse_p02,
        vmax=rmse_p98,
        annotate=False,
        output_file=rmse_brain_plot,
    )


def plot_adaptive_mask(
*,
optcom: np.ndarray,