
Commit

simplify replace
Signed-off-by: Wenqi Li <[email protected]>
wyli committed Jan 23, 2023
1 parent bffc760 commit f4e71c0
Showing 3 changed files with 26 additions and 73 deletions.
17 changes: 14 additions & 3 deletions monai/transforms/inverse.py
@@ -87,14 +87,25 @@ def get_transform_info(self) -> dict:
             TraceKeys.DO_TRANSFORM: self._do_transform if hasattr(self, "_do_transform") else False,
         }
 
-    def push_transform(self, *args, **kwargs):
+    def push_transform(self, data, *args, **kwargs):
         transform_info = self.get_transform_info()
+        lazy_eval = transform_info.get(TraceKeys.LAZY_EVALUATION, False)
+        do_transform = transform_info.get(TraceKeys.DO_TRANSFORM, False)
         if not kwargs:
             kwargs = {}
         kwargs["transform_info"] = transform_info
+        replace = kwargs.pop("replace", False)
+        if replace and isinstance(data, MetaTensor) and get_track_meta():
+            if not lazy_eval:
+                xform = self.pop_transform(data, check=False) if do_transform else {}
+                return self.push_transform(data, extra_info=xform)
+            elif do_transform:
+                return self.push_transform(data, pending=data.pending_operations.pop())  # type: ignore
+            else:
+                return data
         if transform_info.get(TraceKeys.LAZY_EVALUATION, False):
-            return TraceableTransform.track_pending_transform(*args, **kwargs)
-        return TraceableTransform.track_transform(*args, **kwargs)
+            return TraceableTransform.track_pending_transform(data, *args, **kwargs)
+        return TraceableTransform.track_transform(data, *args, **kwargs)
 
     @classmethod
     def track_transform(
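
The net effect of the new `replace=True` keyword is that individual transforms no longer branch on `get_track_meta()` and `self.lazy_evaluation` themselves; `push_transform` pops and re-pushes the applied or pending operation internally. A quick user-level sanity check (a minimal sketch, assuming a MONAI build containing this commit; tensor shape and probability settings are illustrative):

import torch
from monai.data import MetaTensor
from monai.transforms import RandFlip
from monai.utils import TraceKeys

img = MetaTensor(torch.rand(1, 8, 8, 8))   # channel-first 3D image
flip = RandFlip(prob=1.0, spatial_axis=0)  # always apply, flip the first spatial axis

out = flip(img)
# push_transform(out, replace=True) should record the op as the old inline code did
print(out.applied_operations[-1][TraceKeys.CLASS_NAME])  # expected: "RandFlip"

# invertibility should be unaffected by the refactor
restored = flip.inverse(out)
print(torch.allclose(restored, img))  # expected: True
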
38 changes: 5 additions & 33 deletions monai/transforms/spatial/array.py
@@ -1131,12 +1131,7 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:
         else:
             out = convert_to_tensor(img, track_meta=get_track_meta())
 
-        if get_track_meta():
-            if not self.lazy_evaluation:
-                maybe_rot90_info = self.pop_transform(out, check=False) if self._do_transform else {}
-                self.push_transform(out, extra_info=maybe_rot90_info)
-            elif self._do_transform:
-                self.push_transform(out, pending=out.pending_operations.pop())  # type: ignore
+        self.push_transform(out, replace=True)
         return out
 
     def inverse(self, data: torch.Tensor) -> torch.Tensor:
@@ -1261,13 +1256,7 @@ def __call__(
             out = rotator(img)
         else:
             out = convert_to_tensor(img, track_meta=get_track_meta(), dtype=torch.float32)
-        if get_track_meta():
-            if not self.lazy_evaluation:
-                rot_info = self.pop_transform(out, check=False) if self._do_transform else {}
-                self.push_transform(out, extra_info=rot_info)
-            elif self._do_transform:
-                p = out.pending_operations.pop()  # type: ignore
-                self.push_transform(out, pending=p)
+        self.push_transform(out, replace=True)
         return out
 
     def inverse(self, data: torch.Tensor) -> torch.Tensor:
@@ -1309,13 +1298,7 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:
             self.randomize(None)
         out = self.flipper(img) if self._do_transform else img
         out = convert_to_tensor(out, track_meta=get_track_meta())
-        if get_track_meta():
-            if not self.lazy_evaluation:
-                xform_info = self.pop_transform(out, check=False) if self._do_transform else {}
-                self.push_transform(out, extra_info=xform_info)
-            elif self._do_transform:
-                p = out.pending_operations.pop()  # type: ignore
-                self.push_transform(out, pending=p)
+        self.push_transform(out, replace=True)
         return out
 
     def inverse(self, data: torch.Tensor) -> torch.Tensor:
@@ -1369,12 +1352,7 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:
             out = self.flipper(img)
         else:
             out = convert_to_tensor(img, track_meta=get_track_meta())
-        if get_track_meta():
-            if not self.lazy_evaluation:
-                xform = self.pop_transform(out, check=False) if self._do_transform else {}
-                self.push_transform(out, extra_info=xform)
-            elif self._do_transform:
-                self.push_transform(out, pending=out.pending_operations.pop())  # type: ignore
+        self.push_transform(out, replace=True)
         return out
 
     def inverse(self, data: torch.Tensor) -> torch.Tensor:
@@ -1503,13 +1481,7 @@ def __call__(
             )
             xform.lazy_evaluation = self.lazy_evaluation
             out = xform(img)
-        if get_track_meta():
-            if not self.lazy_evaluation:
-                z_info = self.pop_transform(out, check=False) if self._do_transform else {}
-                self.push_transform(out, extra_info=z_info)
-            elif self._do_transform:
-                p = out.pending_operations.pop()
-                self.push_transform(out, pending=p)
+        self.push_transform(out, replace=True)
         return out  # type: ignore
 
     def inverse(self, data: torch.Tensor) -> torch.Tensor:
44 changes: 7 additions & 37 deletions monai/transforms/spatial/dictionary.py
@@ -598,13 +598,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, t
         rotator.lazy_evaluation = self.lazy_evaluation
         for key in self.key_iterator(d):
             d[key] = rotator(d[key]) if self._do_transform else convert_to_tensor(d[key], track_meta=get_track_meta())
-            if get_track_meta():
-                if not self.lazy_evaluation:
-                    xform = self.pop_transform(d[key], check=False) if self._do_transform else {}
-                    self.push_transform(d[key], extra_info=xform)
-                elif self._do_transform:
-                    self.push_transform(d[key], pending=d[key].pending_operations.pop())  # type: ignore
-
+            self.push_transform(d[key], replace=True)
         return d
 
     def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
@@ -942,12 +936,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
                 d[key] = self.rand_affine(d[key], mode=mode, padding_mode=padding_mode, grid=grid)  # type: ignore
             else:
                 d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)
-            if get_track_meta():
-                if not self.lazy_evaluation:
-                    xform = self.pop_transform(d[key], check=False) if do_resampling else {}
-                    self.push_transform(d[key], extra_info=xform)
-                elif do_resampling and isinstance(d[key], MetaTensor):
-                    self.push_transform(d[key], pending=d[key].pending_operations.pop())  # type: ignore
+            self._do_transform = do_resampling  # TODO: unify self._do_transform and do_resampling
+            self.push_transform(d[key], replace=True)
         return d
 
     def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
@@ -1320,12 +1310,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
                 d[key] = self.flipper(d[key])
             else:
                 d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
-            if get_track_meta():
-                if not self.lazy_evaluation:
-                    xform_info = self.pop_transform(d[key], check=False) if self._do_transform else {}
-                    self.push_transform(d[key], extra_info=xform_info)
-                elif self._do_transform:
-                    self.push_transform(d[key], pending=d[key].pending_operations.pop())  # type: ignore
+            self.push_transform(d[key], replace=True)
         return d
 
     def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
@@ -1386,12 +1371,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
                 d[key] = self.flipper(d[key], randomize=False)
             else:
                 d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
-            if get_track_meta():
-                if not self.lazy_evaluation:
-                    xform = self.pop_transform(d[key], check=False) if self._do_transform else {}
-                    self.push_transform(d[key], extra_info=xform)
-                elif self._do_transform:
-                    self.push_transform(d[key], pending=d[key].pending_operations.pop())  # type: ignore
+            self.push_transform(d[key], replace=True)
         return d
 
     def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
@@ -1564,12 +1544,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
                 )
             else:
                 d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)
-            if get_track_meta():
-                if not self.lazy_evaluation:
-                    rot_info = self.pop_transform(d[key], check=False) if self._do_transform else {}
-                    self.push_transform(d[key], extra_info=rot_info)
-                elif self._do_transform:
-                    self.push_transform(d[key], pending=d[key].pending_operations.pop())  # type: ignore
+            self.push_transform(d[key], replace=True)
         return d
 
     def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
@@ -1744,12 +1719,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torc
                 )
             else:
                 d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)
-            if get_track_meta():
-                if not self.lazy_evaluation:
-                    xform = self.pop_transform(d[key], check=False) if self._do_transform else {}
-                    self.push_transform(d[key], extra_info=xform)
-                elif self._do_transform:
-                    self.push_transform(d[key], pending=d[key].pending_operations.pop())  # type: ignore
+            self.push_transform(d[key], replace=True)
         return d
 
     def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
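
The dictionary wrappers follow the same per-key pattern, calling `self.push_transform(d[key], replace=True)` inside the `key_iterator` loop. A matching sketch with RandFlipd (key name and shapes are illustrative, not taken from the diff):

import torch
from monai.data import MetaTensor
from monai.transforms import RandFlipd

data = {"image": MetaTensor(torch.rand(1, 8, 8, 8))}
flipd = RandFlipd(keys="image", prob=1.0, spatial_axis=0)

out = flipd(data)
# each key's MetaTensor should carry one recorded operation after the call
print(len(out["image"].applied_operations))  # expected: 1

# the dict-level inverse should still round-trip the data
restored = flipd.inverse(out)
print(torch.allclose(restored["image"], data["image"]))  # expected: True
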
