Skip to content

Commit

Permalink
Recursive Item Mapping for Nested Lists in Compose (Project-MONAI#8187)
Browse files Browse the repository at this point in the history
Fixes Project-MONAI#8186.

### Description

A few sentences describing the changes proposed in this pull request.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: YunLiu <[email protected]>
Co-authored-by: Ben Murray <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored and thibaultdvx committed Feb 13, 2025
1 parent 2759275 commit a33a212
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 17 deletions.
39 changes: 28 additions & 11 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
def execute_compose(
data: NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor],
transforms: Sequence[Any],
map_items: bool = True,
map_items: bool | int = True,
unpack_items: bool = False,
start: int = 0,
end: int | None = None,
Expand All @@ -65,8 +65,13 @@ def execute_compose(
Args:
data: a tensor-like object to be transformed
transforms: a sequence of transforms to be carried out
map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
defaults to `True`.
map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,
it can behave as follows:
- Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied
to the first level of items in `data`.
- If an integer is provided, it specifies the maximum level of nesting to which the transformation
should be recursively applied. This allows treating multi-sample transforms applied after another
multi-sample transform while controlling how deep the mapping goes.
unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
defaults to `False`.
start: the index of the first transform to be executed. If not set, this defaults to 0
Expand Down Expand Up @@ -205,8 +210,14 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform):
Args:
transforms: sequence of callables.
map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
defaults to `True`.
map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,
it can behave as follows:
- Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied
to the first level of items in `data`.
- If an integer is provided, it specifies the maximum level of nesting to which the transformation
should be recursively applied. This allows treating multi-sample transforms applied after another
multi-sample transform while controlling how deep the mapping goes.
unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
defaults to `False`.
log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
Expand All @@ -227,7 +238,7 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform):
def __init__(
self,
transforms: Sequence[Callable] | Callable | None = None,
map_items: bool = True,
map_items: bool | int = True,
unpack_items: bool = False,
log_stats: bool | str = False,
lazy: bool | None = False,
Expand All @@ -238,9 +249,9 @@ def __init__(
if transforms is None:
transforms = []

if not isinstance(map_items, bool):
if not isinstance(map_items, (bool, int)):
raise ValueError(
f"Argument 'map_items' should be boolean. Got {type(map_items)}."
f"Argument 'map_items' should be boolean or int. Got {type(map_items)}."
"Check brackets when passing a sequence of callables."
)

Expand Down Expand Up @@ -391,8 +402,14 @@ class OneOf(Compose):
transforms: sequence of callables.
weights: probabilities corresponding to each callable in transforms.
Probabilities are normalized to sum to one.
map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
defaults to `True`.
map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,
it can behave as follows:
- Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied
to the first level of items in `data`.
- If an integer is provided, it specifies the maximum level of nesting to which the transformation
should be recursively applied. This allows treating multi-sample transforms applied after another
multi-sample transform while controlling how deep the mapping goes.
unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
defaults to `False`.
log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
Expand All @@ -414,7 +431,7 @@ def __init__(
self,
transforms: Sequence[Callable] | Callable | None = None,
weights: Sequence[float] | float | None = None,
map_items: bool = True,
map_items: bool | int = True,
unpack_items: bool = False,
log_stats: bool | str = False,
lazy: bool | None = False,
Expand Down
21 changes: 15 additions & 6 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ def _apply_transform(
def apply_transform(
transform: Callable[..., ReturnType],
data: Any,
map_items: bool = True,
map_items: bool | int = True,
unpack_items: bool = False,
log_stats: bool | str = False,
lazy: bool | None = None,
overrides: dict | None = None,
) -> list[ReturnType] | ReturnType:
) -> list[Any] | ReturnType:
"""
Transform `data` with `transform`.
Expand All @@ -117,8 +117,13 @@ def apply_transform(
Args:
transform: a callable to be used to transform `data`.
data: an object to be transformed.
map_items: whether to apply transform to each item in `data`,
if `data` is a list or tuple. Defaults to True.
map_items: controls whether to apply a transformation to each item in `data`. If `data` is a list or tuple,
it can behave as follows:
- Defaults to True, which is equivalent to `map_items=1`, meaning the transformation will be applied
to the first level of items in `data`.
- If an integer is provided, it specifies the maximum level of nesting to which the transformation
should be recursively applied. This allows treating multi-sample transforms applied after another
multi-sample transform while controlling how deep the mapping goes.
unpack_items: whether to unpack parameters using `*`. Defaults to False.
log_stats: log errors when they occur in the processing pipeline. By default, this is set to False, which
disables the logger for processing pipeline errors. Setting it to None or True will enable logging to the
Expand All @@ -136,8 +141,12 @@ def apply_transform(
Union[List[ReturnType], ReturnType]: The return type of `transform` or a list thereof.
"""
try:
if isinstance(data, (list, tuple)) and map_items:
return [_apply_transform(transform, item, unpack_items, lazy, overrides, log_stats) for item in data]
map_items_ = int(map_items) if isinstance(map_items, bool) else map_items
if isinstance(data, (list, tuple)) and map_items_ > 0:
return [
apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
for item in data
]
return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
except Exception as e:
# if in debug mode, don't swallow exception so that the breakpoint
Expand Down
14 changes: 14 additions & 0 deletions tests/transforms/compose/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,20 @@ def b(i, i2):
self.assertEqual(mt.Compose(transforms, unpack_items=True)(data), expected)
self.assertEqual(execute_compose(data, transforms, unpack_items=True), expected)

def test_list_non_dict_compose_with_unpack_map_2(self):

def a(i, i2):
return i + "a", i2 + "a2"

def b(i, i2):
return i + "b", i2 + "b2"

transforms = [a, b, a, b]
data = [[("", ""), ("", "")], [("t", "t"), ("t", "t")]]
expected = [[("abab", "a2b2a2b2"), ("abab", "a2b2a2b2")], [("tabab", "ta2b2a2b2"), ("tabab", "ta2b2a2b2")]]
self.assertEqual(mt.Compose(transforms, map_items=2, unpack_items=True)(data), expected)
self.assertEqual(execute_compose(data, transforms, map_items=2, unpack_items=True), expected)

def test_list_dict_compose_no_map(self):

def a(d): # transform to handle dict data
Expand Down

0 comments on commit a33a212

Please sign in to comment.