Skip to content

Commit

Permalink
fixes tests
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli committed Jan 25, 2023
1 parent 7699634 commit 45631e0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 15 deletions.
2 changes: 2 additions & 0 deletions monai/data/meta_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def copy_meta_from(self, input_objs, copy_attr=True, keys=None):
return self with the updated ``__dict__``.
"""
first_meta = input_objs if isinstance(input_objs, MetaObj) else first(input_objs, default=self)
if not hasattr(first_meta, "__dict__"):
return self
first_meta = first_meta.__dict__
keys = first_meta.keys() if keys is None else keys
if not copy_attr:
Expand Down
21 changes: 6 additions & 15 deletions tests/test_traceable_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,19 @@

from __future__ import annotations

import torch
import unittest

from monai.transforms.inverse import TraceableTransform


class _TraceTest(TraceableTransform):
def __call__(self, data):
self.push_transform(data)
self.push_transform(data, "image")
return data

def pop(self, data):
self.pop_transform(data)
self.pop_transform(data, "image")
return data


Expand All @@ -34,21 +35,11 @@ def test_default(self):

data = {"image": "test"}
data = a(data) # adds to the stack
self.assertTrue(isinstance(data[expected_key], list))
self.assertEqual(data[expected_key][0]["class"], "_TraceTest")
self.assertEqual(data["image"], "test")

data = {"image": torch.tensor(1.0)}
data = a(data) # adds to the stack
self.assertEqual(len(data[expected_key]), 2)
self.assertEqual(data[expected_key][-1]["class"], "_TraceTest")

with self.assertRaises(IndexError):
a.pop({"test": "test"}) # no stack in the data
data = a.pop(data)
data = a.pop(data)
self.assertEqual(data[expected_key], [])

with self.assertRaises(IndexError): # no more items
a.pop(data)
self.assertEqual(data["image"].applied_operations[0]['class'], "_TraceTest")


if __name__ == "__main__":
Expand Down

0 comments on commit 45631e0

Please sign in to comment.