Skip to content

Commit

Permalink
Refactor gscontrol module (ME-ICA#1086)
Browse files Browse the repository at this point in the history
* Refactor the gscontrol function.

* Update docstring.

* Fix things.

* Update ica_reclassify.py

* Update fiu_four_echo_outputs.txt

* Update.

* Incorporate classification tags into MIR.

* Update gscontrol.py

* Clean up gscontrol_raw.

* Fix.

* Update reclassify_debug_out.txt

* Update gscontrol.py
  • Loading branch information
tsalo authored Aug 13, 2024
1 parent b215083 commit 432b0fe
Show file tree
Hide file tree
Showing 8 changed files with 241 additions and 145 deletions.
2 changes: 1 addition & 1 deletion docs/approach.rst
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ presented at MRITogether 2022 for a hands-on tutorial.
Removal of spatially diffuse noise (optional)
*********************************************

:func:`tedana.gscontrol.gscontrol_raw`, :func:`tedana.gscontrol.gscontrol_mmix`
:func:`tedana.gscontrol.gscontrol_raw`, :func:`tedana.gscontrol.minimum_image_regression`

Due to the constraints of ICA, TEDICA is able to identify and remove spatially
localized noise components, but it cannot identify components that are spread
Expand Down
280 changes: 174 additions & 106 deletions tedana/gscontrol.py

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions tedana/metrics/dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,20 @@ def calculate_betas(
betas : (M [x E] x C) array_like
Unstandardized parameter estimates
"""
if len(data.shape) == 2:
if data.ndim == 2:
data_optcom = data
assert data_optcom.shape[1] == mixing.shape[0]
# mean-center optimally-combined data
data_optcom_dm = data_optcom - data_optcom.mean(axis=-1, keepdims=True)
# betas are the result of a normal OLS fit of the mixing matrix
# against the mean-center data
# betas are from a normal OLS fit of the mixing matrix against the mean-centered data
betas = get_coeffs(data_optcom_dm, mixing)
return betas

else:
betas = np.zeros([data.shape[0], data.shape[1], mixing.shape[1]])
for n_echo in range(data.shape[1]):
betas[:, n_echo, :] = get_coeffs(data[:, n_echo, :], mixing)
return betas

return betas


def calculate_psc(
Expand Down
33 changes: 17 additions & 16 deletions tedana/reporting/static_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,23 +164,24 @@ def carpet_plot(
)
)

mir_denoised_img = io_generator.get_name("ICA accepted mir denoised img")
fig, ax = plt.subplots(figsize=(14, 7))
plotting.plot_carpet(
mir_denoised_img,
mask_img,
figure=fig,
axes=ax,
title="High-Kappa Data (Post-MIR)",
)
fig.tight_layout()
fig.savefig(
os.path.join(
io_generator.out_dir,
"figures",
f"{io_generator.prefix}carpet_accepted_mir.svg",
if io_generator.verbose:
mir_denoised_img = io_generator.get_name("ICA accepted mir denoised img")
fig, ax = plt.subplots(figsize=(14, 7))
plotting.plot_carpet(
mir_denoised_img,
mask_img,
figure=fig,
axes=ax,
title="High-Kappa Data (Post-MIR)",
)
fig.tight_layout()
fig.savefig(
os.path.join(
io_generator.out_dir,
"figures",
f"{io_generator.prefix}carpet_accepted_mir.svg",
)
)
)


def plot_component(
Expand Down
3 changes: 1 addition & 2 deletions tedana/tests/data/reclassify_debug_out.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@ sub-testymctestface_references.bib
sub-testymctestface_report.txt
sub-testymctestface_betas_OC.nii.gz
sub-testymctestface_betas_hik_OC.nii.gz
sub-testymctestface_betas_hik_OC_MIR.nii.gz
sub-testymctestface_dataset_description.json
sub-testymctestface_dn_ts_OC.nii.gz
sub-testymctestface_dn_ts_OC_MIR.nii.gz
sub-testymctestface_feats_OC2.nii.gz
sub-testymctestface_hik_ts_OC_MIR.nii.gz
sub-testymctestface_ica_components.nii.gz
sub-testymctestface_ica_cross_component_metrics.json
sub-testymctestface_ica_decision_tree.json
Expand All @@ -22,3 +20,4 @@ sub-testymctestface_ica_orth_mixing.tsv
sub-testymctestface_ica_status_table.tsv
sub-testymctestface_registry.json
sub-testymctestface_sphis_hik.nii.gz
sub-testymctestface_confounds.tsv
33 changes: 21 additions & 12 deletions tedana/tests/test_gscontrol.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@


def test_break_gscontrol_raw():
"""
Ensure that gscontrol_raw fails when input data do not have the right.
shapes.
"""
"""Ensure that gscontrol_raw fails when input data do not have the right shapes."""
n_samples, n_echos, n_vols = 10000, 4, 100
catd = np.empty((n_samples, n_echos, n_vols))
optcom = np.empty((n_samples, n_vols))
Expand All @@ -27,28 +23,41 @@ def test_break_gscontrol_raw():
catd = np.empty((n_samples + 1, n_echos, n_vols))
with pytest.raises(ValueError) as e_info:
gsc.gscontrol_raw(
catd=catd, optcom=optcom, n_echos=n_echos, io_generator=io_generator, dtrank=4
data_cat=catd,
data_optcom=optcom,
n_echos=n_echos,
io_generator=io_generator,
dtrank=4,
)
assert str(e_info.value) == (
f"First dimensions of catd ({catd.shape[0]}) and optcom ({optcom.shape[0]}) do not match"
f"First dimensions of data_cat ({catd.shape[0]}) and data_optcom ({optcom.shape[0]}) "
"do not match"
)

catd = np.empty((n_samples, n_echos + 1, n_vols))
with pytest.raises(ValueError) as e_info:
gsc.gscontrol_raw(
catd=catd, optcom=optcom, n_echos=n_echos, io_generator=io_generator, dtrank=4
data_cat=catd,
data_optcom=optcom,
n_echos=n_echos,
io_generator=io_generator,
dtrank=4,
)
assert str(e_info.value) == (
f"Second dimension of catd ({catd.shape[1]}) does not match n_echos ({n_echos})"
f"Second dimension of data_cat ({catd.shape[1]}) does not match n_echos ({n_echos})"
)

catd = np.empty((n_samples, n_echos, n_vols))
optcom = np.empty((n_samples, n_vols + 1))
with pytest.raises(ValueError) as e_info:
gsc.gscontrol_raw(
catd=catd, optcom=optcom, n_echos=n_echos, io_generator=io_generator, dtrank=4
data_cat=catd,
data_optcom=optcom,
n_echos=n_echos,
io_generator=io_generator,
dtrank=4,
)
assert str(e_info.value) == (
f"Third dimension of catd ({catd.shape[2]}) does not match "
f"second dimension of optcom ({optcom.shape[1]})"
f"Third dimension of data_cat ({catd.shape[2]}) does not match "
f"second dimension of data_optcom ({optcom.shape[1]})"
)
9 changes: 8 additions & 1 deletion tedana/workflows/ica_reclassify.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,14 @@ def ica_reclassify_workflow(

if mir:
io_generator.overwrite = True
gsc.minimum_image_regression(data_oc, mmix, mask_denoise, comptable, io_generator)
gsc.minimum_image_regression(
data_optcom=data_oc,
mixing=mmix,
mask=mask_denoise,
comptable=comptable,
classification_tags=selector.classification_tags,
io_generator=io_generator,
)
io_generator.overwrite = False

# Write out BIDS-compatible description file
Expand Down
16 changes: 14 additions & 2 deletions tedana/workflows/tedana.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,12 @@ def tedana_workflow(

if "gsr" in gscontrol:
# regress out global signal
catd, data_oc = gsc.gscontrol_raw(catd, data_oc, n_echos, io_generator)
catd, data_oc = gsc.gscontrol_raw(
data_cat=catd,
data_optcom=data_oc,
n_echos=n_echos,
io_generator=io_generator,
)

fout = io_generator.save_file(data_oc, "combined img")
LGR.info(f"Writing optimally combined data set: {fout}")
Expand Down Expand Up @@ -886,7 +891,14 @@ def tedana_workflow(
)

if "mir" in gscontrol:
gsc.minimum_image_regression(data_oc, mmix, mask_denoise, comptable, io_generator)
gsc.minimum_image_regression(
data_optcom=data_oc,
mixing=mmix,
mask=mask_denoise,
comptable=comptable,
classification_tags=selector.classification_tags,
io_generator=io_generator,
)

if verbose:
io.writeresults_echoes(catd, mmix, mask_denoise, comptable, io_generator)
Expand Down

0 comments on commit 432b0fe

Please sign in to comment.