Skip to content

Commit

Permalink
update xform applied
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli committed Jan 23, 2023
1 parent 65b4b6c commit bffc760
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
8 changes: 5 additions & 3 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
if get_track_meta():
if not self.lazy_evaluation:
xform = self.pop_transform(d[key], check=False) if do_resampling else {}
self.push_transform(d[key], extra_info={"do_resampling": do_resampling, "rand_affine_info": xform})
self.push_transform(d[key], extra_info=xform)
elif do_resampling and isinstance(d[key], MetaTensor):
self.push_transform(d[key], pending=d[key].pending_operations.pop()) # type: ignore
return d
Expand All @@ -954,9 +954,11 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, Nd
d = dict(data)
for key in self.key_iterator(d):
tr = self.pop_transform(d[key])
do_resampling = tr[TraceKeys.EXTRA_INFO]["do_resampling"]
if TraceKeys.EXTRA_INFO not in tr[TraceKeys.EXTRA_INFO]:
continue
do_resampling = tr[TraceKeys.EXTRA_INFO][TraceKeys.EXTRA_INFO]["do_resampling"]
if do_resampling:
d[key].applied_operations.append(tr[TraceKeys.EXTRA_INFO]["rand_affine_info"]) # type: ignore
d[key].applied_operations.append(tr[TraceKeys.EXTRA_INFO]) # type: ignore
d[key] = self.rand_affine.inverse(d[key]) # type: ignore

return d
Expand Down
8 changes: 5 additions & 3 deletions tests/test_rand_affined.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,12 @@ def test_rand_affined(self, input_param, input_data, expected_val, track_meta):

# affine should be tensor because the resampler only supports pytorch backend
if isinstance(res["img"], MetaTensor) and "extra_info" in res["img"].applied_operations[0]:
if not res["img"].applied_operations[-1]["extra_info"]["do_resampling"]:
if not res["img"].applied_operations[-1]["extra_info"]:
return
affine_img = res["img"].applied_operations[0]["extra_info"]["rand_affine_info"]["extra_info"]["affine"]
affine_seg = res["seg"].applied_operations[0]["extra_info"]["rand_affine_info"]["extra_info"]["affine"]
if not res["img"].applied_operations[-1]["extra_info"]["extra_info"]["do_resampling"]:
return
affine_img = res["img"].applied_operations[0]["extra_info"]["extra_info"]["affine"]
affine_seg = res["seg"].applied_operations[0]["extra_info"]["extra_info"]["affine"]
assert_allclose(affine_img, affine_seg, rtol=_rtol, atol=1e-3)

res_inv = g.inverse(res)
Expand Down

0 comments on commit bffc760

Please sign in to comment.