From f4e71c07c04d5eb88dcf7399d70a27cae1679d84 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 23 Jan 2023 15:58:02 +0000 Subject: [PATCH] simplify replace Signed-off-by: Wenqi Li --- monai/transforms/inverse.py | 17 ++++++++-- monai/transforms/spatial/array.py | 38 +++------------------- monai/transforms/spatial/dictionary.py | 44 ++++---------------------- 3 files changed, 26 insertions(+), 73 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index efa3270479a..74f4a509948 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -87,14 +87,25 @@ def get_transform_info(self) -> dict: TraceKeys.DO_TRANSFORM: self._do_transform if hasattr(self, "_do_transform") else False, } - def push_transform(self, *args, **kwargs): + def push_transform(self, data, *args, **kwargs): transform_info = self.get_transform_info() + lazy_eval = transform_info.get(TraceKeys.LAZY_EVALUATION, False) + do_transform = transform_info.get(TraceKeys.DO_TRANSFORM, False) if not kwargs: kwargs = {} kwargs["transform_info"] = transform_info + replace = kwargs.pop("replace", False) + if replace and isinstance(data, MetaTensor) and get_track_meta(): + if not lazy_eval: + xform = self.pop_transform(data, check=False) if do_transform else {} + return self.push_transform(data, extra_info=xform) + elif do_transform: + return self.push_transform(data, pending=data.pending_operations.pop()) # type: ignore + else: + return data if transform_info.get(TraceKeys.LAZY_EVALUATION, False): - return TraceableTransform.track_pending_transform(*args, **kwargs) - return TraceableTransform.track_transform(*args, **kwargs) + return TraceableTransform.track_pending_transform(data, *args, **kwargs) + return TraceableTransform.track_transform(data, *args, **kwargs) @classmethod def track_transform( diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 0373bae8213..469e929a2b5 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -1131,12 +1131,7 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: else: out = convert_to_tensor(img, track_meta=get_track_meta()) - if get_track_meta(): - if not self.lazy_evaluation: - maybe_rot90_info = self.pop_transform(out, check=False) if self._do_transform else {} - self.push_transform(out, extra_info=maybe_rot90_info) - elif self._do_transform: - self.push_transform(out, pending=out.pending_operations.pop()) # type: ignore + self.push_transform(out, replace=True) return out def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1261,13 +1256,7 @@ def __call__( out = rotator(img) else: out = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32) - if get_track_meta(): - if not self.lazy_evaluation: - rot_info = self.pop_transform(out, check=False) if self._do_transform else {} - self.push_transform(out, extra_info=rot_info) - elif self._do_transform: - p = out.pending_operations.pop() # type: ignore - self.push_transform(out, pending=p) + self.push_transform(out, replace=True) return out def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1309,13 +1298,7 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: self.randomize(None) out = self.flipper(img) if self._do_transform else img out = convert_to_tensor(out, track_meta=get_track_meta()) - if get_track_meta(): - if not self.lazy_evaluation: - xform_info = self.pop_transform(out, check=False) if self._do_transform else {} - self.push_transform(out, extra_info=xform_info) - elif self._do_transform: - p = out.pending_operations.pop() # type: ignore - self.push_transform(out, pending=p) + self.push_transform(out, replace=True) return out def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1369,12 +1352,7 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: out = self.flipper(img) else: out = convert_to_tensor(img, track_meta=get_track_meta()) - if get_track_meta(): - if not self.lazy_evaluation: - xform = self.pop_transform(out, check=False) if self._do_transform else {} - self.push_transform(out, extra_info=xform) - elif self._do_transform: - self.push_transform(out, pending=out.pending_operations.pop()) # type: ignore + self.push_transform(out, replace=True) return out def inverse(self, data: torch.Tensor) -> torch.Tensor: @@ -1503,13 +1481,7 @@ def __call__( ) xform.lazy_evaluation = self.lazy_evaluation out = xform(img) - if get_track_meta(): - if not self.lazy_evaluation: - z_info = self.pop_transform(out, check=False) if self._do_transform else {} - self.push_transform(out, extra_info=z_info) - elif self._do_transform: - p = out.pending_operations.pop() - self.push_transform(out, pending=p) + self.push_transform(out, replace=True) return out # type: ignore def inverse(self, data: torch.Tensor) -> torch.Tensor: diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index adcda2babc2..a15dda7ae9d 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -598,13 +598,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, t rotator.lazy_evaluation = self.lazy_evaluation for key in self.key_iterator(d): d[key] = rotator(d[key]) if self._do_transform else convert_to_tensor(d[key], track_meta=get_track_meta()) - if get_track_meta(): - if not self.lazy_evaluation: - xform = self.pop_transform(d[key], check=False) if self._do_transform else {} - self.push_transform(d[key], extra_info=xform) - elif self._do_transform: - self.push_transform(d[key], pending=d[key].pending_operations.pop()) # type: ignore - + self.push_transform(d[key], replace=True) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -942,12 +936,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N d[key] = self.rand_affine(d[key], mode=mode, padding_mode=padding_mode, grid=grid) # type: ignore else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32) - if get_track_meta(): - if not self.lazy_evaluation: - xform = self.pop_transform(d[key], check=False) if do_resampling else {} - self.push_transform(d[key], extra_info=xform) - elif do_resampling and isinstance(d[key], MetaTensor): - self.push_transform(d[key], pending=d[key].pending_operations.pop()) # type: ignore + self._do_transform = do_resampling # TODO: unify self._do_transform and do_resampling + self.push_transform(d[key], replace=True) return d def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: @@ -1320,12 +1310,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc d[key] = self.flipper(d[key]) else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) - if get_track_meta(): - if not self.lazy_evaluation: - xform_info = self.pop_transform(d[key], check=False) if self._do_transform else {} - self.push_transform(d[key], extra_info=xform_info) - elif self._do_transform: - self.push_transform(d[key], pending=d[key].pending_operations.pop()) # type: ignore + self.push_transform(d[key], replace=True) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -1386,12 +1371,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc d[key] = self.flipper(d[key], randomize=False) else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta()) - if get_track_meta(): - if not self.lazy_evaluation: - xform = self.pop_transform(d[key], check=False) if self._do_transform else {} - self.push_transform(d[key], extra_info=xform) - elif self._do_transform: - self.push_transform(d[key], pending=d[key].pending_operations.pop()) # type: ignore + self.push_transform(d[key], replace=True) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -1564,12 +1544,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc ) else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32) - if get_track_meta(): - if not self.lazy_evaluation: - rot_info = self.pop_transform(d[key], check=False) if self._do_transform else {} - self.push_transform(d[key], extra_info=rot_info) - elif self._do_transform: - self.push_transform(d[key], pending=d[key].pending_operations.pop()) # type: ignore + self.push_transform(d[key], replace=True) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: @@ -1744,12 +1719,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc ) else: d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32) - if get_track_meta(): - if not self.lazy_evaluation: - xform = self.pop_transform(d[key], check=False) if self._do_transform else {} - self.push_transform(d[key], extra_info=xform) - elif self._do_transform: - self.push_transform(d[key], pending=d[key].pending_operations.pop()) # type: ignore + self.push_transform(d[key], replace=True) return d def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: