Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Add an option to cut streamlines #1119

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 41 additions & 7 deletions scilpy/tractograms/streamline_and_mask_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,8 @@ def cut_streamlines_with_mask(


def cut_streamlines_between_labels(
sft, label_data, label_ids=None, min_len=0, processes=1
sft, label_data, label_ids=None, min_len=0,
one_point_in_roi=False, no_point_in_roi=False, processes=1
):
"""
Cut streamlines so their segment are going from blob #1 to blob #2 in a
Expand All @@ -356,6 +357,10 @@ def cut_streamlines_between_labels(
in the label map will be used.
min_len: float
Minimum length from the resulting streamlines.
one_point_in_roi: bool
If True, one point in each ROI will be kept.
no_point_in_roi: bool
If True, no point in the ROIs will be kept.

Returns
-------
Expand Down Expand Up @@ -394,7 +399,8 @@ def cut_streamlines_between_labels(
# Trim streamlines with the mask and return the new streamlines
pool = Pool(processes)
lists_of_new_strmls = pool.starmap(
_cut_streamline_with_labels, [(i, s, pt, label_data_1, label_data_2)
_cut_streamline_with_labels, [(i, s, pt, label_data_1, label_data_2,
one_point_in_roi, no_point_in_roi)
for (i, s, pt) in zip(
indices, sft.streamlines,
points_to_idx)])
Expand All @@ -416,7 +422,8 @@ def cut_streamlines_between_labels(


def _cut_streamline_with_labels(
idx, streamline, pts_to_idx, roi_data_1, roi_data_2
idx, streamline, pts_to_idx, roi_data_1, roi_data_2,
one_point_in_roi=False, no_point_in_roi=False
):
"""
Cut streamlines so their segment are going from label mask #1 to label
Expand All @@ -435,6 +442,10 @@ def _cut_streamline_with_labels(
Boolean array representing the region #1.
roi_data_2: np.ndarray
Boolean array representing the region #2.
one_point_in_roi: bool
If True, one point in each ROI will be kept.
no_point_in_roi: bool
If True, no point in the ROIs will be kept.

Returns
-------
Expand All @@ -445,7 +456,9 @@ def _cut_streamline_with_labels(
# ROIs
in_strl_idx, out_strl_idx = _intersects_two_rois(roi_data_1,
roi_data_2,
idx)
idx,
one_point_in_roi=one_point_in_roi,
no_point_in_roi=no_point_in_roi)

cut_strl = None
# If the streamline intersects both ROIs
Expand Down Expand Up @@ -497,7 +510,8 @@ def _get_longest_streamline_segment_in_roi(all_strl_indices):
return strl_indices


def _intersects_two_rois(roi_data_1, roi_data_2, strl_indices):
def _intersects_two_rois(roi_data_1, roi_data_2, strl_indices,
one_point_in_roi=False, no_point_in_roi=False):
""" Find the first and last "voxels" of the streamline that are in the
ROIs.

Expand All @@ -509,6 +523,10 @@ def _intersects_two_rois(roi_data_1, roi_data_2, strl_indices):
Boolean array representing the region #2
strl_indices: list of tuple (N, 3)
3D indices of the voxels intersected by the streamline
one_point_in_roi: bool
If True, one point in each ROI will be kept.
no_point_in_roi: bool
If True, no point in the ROIs will be kept.

Returns
-------
Expand Down Expand Up @@ -549,8 +567,24 @@ def _intersects_two_rois(roi_data_1, roi_data_2, strl_indices):

# Get the index of the first and last "voxels" of the streamline that are
# in the ROIs
in_strl_idx = in_strl_indices[0]
out_strl_idx = out_strl_indices[-1]
if not one_point_in_roi and not no_point_in_roi:
in_strl_idx = in_strl_indices[0]
out_strl_idx = out_strl_indices[-1]
else:
if one_point_in_roi:
add_indice = 0
elif no_point_in_roi:
add_indice = 1

if in_strl_indices[-1] is not None:
in_strl_idx = in_strl_indices[-1] + add_indice
else:
in_strl_idx = None

if out_strl_indices[0] is not None:
out_strl_idx = out_strl_indices[0] - add_indice
else:
out_strl_idx = None

return in_strl_idx, out_strl_idx

Expand Down
21 changes: 21 additions & 0 deletions scripts/scil_tractogram_cut_streamlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@
--label: Label containing 2 blobs. Streamlines will be cut so they go from the
first label region to the second label region. The two blobs must be disjoint.

Default: Will cut the streamlines according to the mask. New streamlines
may be generated if the mask is disjoint.

--keep_longest: Will keep the longest segment of the streamline that is
within the mask. No new streamlines will be generated.

--trim_endpoints: Will only remove the endpoints of the streamlines that
are outside the mask. The middle part of the streamline may go
outside the mask, to compensate for hole in the mask for example. No new
streamlines will be generated.

Both scenarios will erase data_per_point and data_per_streamline. Streamlines
will be extended so they reach the boundary of the mask or the two labels,
therefore won't be equal to the input streamlines.
Expand Down Expand Up @@ -94,6 +105,14 @@ def _build_arg_parser():
help='If set, will only remove the endpoints of the '
'streamlines that are outside the mask.')

g1 = p.add_argument_group('Cutting options', 'Options for cutting '
'streamlines with --labels.')
g3 = g1.add_mutually_exclusive_group()
g3.add_argument('--one_point_in_roi', action='store_true',
help='If set, will keep one point in each label.')
g3.add_argument('--no_point_in_roi', action='store_true',
help='If set, will not keep any point in the labels.')

add_compression_arg(p)
add_overwrite_arg(p)
add_processes_arg(p)
Expand Down Expand Up @@ -153,6 +172,8 @@ def main():

new_sft = cut_streamlines_between_labels(
sft, label_data, args.label_ids, min_len=args.min_length,
one_point_in_roi=args.one_point_in_roi,
no_point_in_roi=args.no_point_in_roi,
processes=args.nbr_processes)

# Saving
Expand Down
26 changes: 26 additions & 0 deletions scripts/tests/test_tractogram_cut_streamlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,29 @@ def test_execution_labels_error_trim(script_runner, monkeypatch):
'--resample', '0.2', '--compress', '0.1'
'--trim_endpoints')
assert not ret.success


def test_execution_labels_one_points(script_runner, monkeypatch):
monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
in_tractogram = os.path.join(SCILPY_HOME, 'connectivity',
'bundle_all_1mm.trk')
in_labels = os.path.join(SCILPY_HOME, 'connectivity',
'endpoints_atlas.nii.gz')
ret = script_runner.run('scil_tractogram_cut_streamlines.py',
in_tractogram, '--labels', in_labels,
'out_tractogram_cut.trk', '-f',
'--no_point_in_roi', '--label_ids', '1', '10',)
assert ret.success


def test_execution_labels_one_point(script_runner, monkeypatch):
monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
in_tractogram = os.path.join(SCILPY_HOME, 'connectivity',
'bundle_all_1mm.trk')
in_labels = os.path.join(SCILPY_HOME, 'connectivity',
'endpoints_atlas.nii.gz')
ret = script_runner.run('scil_tractogram_cut_streamlines.py',
in_tractogram, '--labels', in_labels,
'out_tractogram_cut.trk', '-f',
'--one_point_in_roi', '--label_ids', '1', '10',)
assert ret.success
Loading