
MetaTensor: re-enable full tests #4462

Closed
12 tasks done
wyli opened this issue Jun 6, 2022 · 3 comments · Fixed by #4729

wyli (Contributor) commented Jun 6, 2022

Is your feature request related to a problem? Please describe.
Follow-up of prototype #4371

This ticket tracks the breaking changes needed on the feature/MetaTensor branch while rolling out MetaTensor support for all transforms in monai.transforms.croppad.
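Rolling out MetaTensor support in the crop/pad transforms means each transform has to keep the metadata consistent with the array it returns. For a crop, one piece of that bookkeeping is the affine: the voxel that becomes index 0 must keep its world coordinate, so the new translation column is the old affine applied to the crop's start voxel. The stdlib sketch below only illustrates that arithmetic, assuming a (d+1)x(d+1) homogeneous affine stored as nested lists; `crop_affine` and `apply_affine` are hypothetical helpers, not MONAI's implementation.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def crop_affine(affine: Matrix, start: Sequence[int]) -> Matrix:
    """Hypothetical helper: affine of an image cropped so voxel `start` becomes index 0.

    The new translation is the old affine applied to the homogeneous start voxel,
    so every remaining voxel keeps its world coordinate.
    """
    d = len(start)
    assert len(affine) == d + 1, "expected a (d+1)x(d+1) homogeneous affine"
    new_affine = [row[:] for row in affine]  # copy; the rotation/zoom block is unchanged
    homogeneous = list(start) + [1]
    for i in range(d):
        new_affine[i][d] = sum(affine[i][j] * homogeneous[j] for j in range(d + 1))
    return new_affine


def apply_affine(affine: Matrix, index: Sequence[int]) -> List[float]:
    """Map a voxel index to world coordinates."""
    d = len(index)
    homogeneous = list(index) + [1]
    return [sum(affine[i][j] * homogeneous[j] for j in range(d + 1)) for i in range(d)]


# 2D example: 2 mm spacing, origin at (10, -5)
affine = [[2.0, 0.0, 10.0], [0.0, 2.0, -5.0], [0.0, 0.0, 1.0]]
cropped = crop_affine(affine, start=(3, 4))
# voxel (3, 4) of the original and voxel (0, 0) of the crop are the same point
assert apply_affine(affine, (3, 4)) == apply_affine(cropped, (0, 0))
```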

These tests/checks are temporarily muted and should be re-enabled before the new features are merged into the dev branch:

cc @rijobro

(latest successful build is at 60a22ff)

wyli (Contributor, issue author) commented Jun 14, 2022

(update: this is addressed by #4506)

tests/test_integration_fast_train.py does not currently use MetaTensor: an EnsureTyped transform was added to convert the MetaTensors into regular tensors:

EnsureTyped(keys=["image", "label"], drop_meta=True),
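What `drop_meta=True` is used for here can be pictured with a toy stand-in: a value that carries metadata alongside its data, and a dictionary transform that replaces it with the bare data for the selected keys. `MetaValue` and `drop_meta_d` are hypothetical illustrations, not MONAI classes; in MONAI the conversion is done by `EnsureTyped` (and, for a single MetaTensor, by `as_tensor()`, which the reproduction script uses).

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable


@dataclass
class MetaValue:
    """Hypothetical toy stand-in for a tensor that carries metadata."""
    data: Any
    meta: Dict[str, Any] = field(default_factory=dict)


def drop_meta_d(sample: Dict[str, Any], keys: Iterable[str]) -> Dict[str, Any]:
    """Return a shallow copy of `sample` where the given keys hold bare data."""
    out = dict(sample)
    for key in keys:
        value = out[key]
        if isinstance(value, MetaValue):
            out[key] = value.data  # discard the metadata, keep only the payload
    return out


sample = {
    "image": MetaValue([0.1, 0.2], {"affine": "eye(4)"}),
    "label": MetaValue([0, 1], {"affine": "eye(4)"}),
}
plain = drop_meta_d(sample, keys=["image", "label"])
print(plain)  # {'image': [0.1, 0.2], 'label': [0, 1]}
```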

The primary issue was the network's forward-pass performance; I reproduced it with a smaller example:

import os
import shutil
import tempfile
import time
import unittest
from glob import glob

import nibabel as nib
import numpy as np
import torch

import monai
from monai.data import CacheDataset, ThreadDataLoader, create_test_image_3d
from monai.networks.nets import UNet
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    LoadImaged,
    ToDeviced,
)
from monai.utils import set_determinism
from tests.utils import skip_if_no_cuda


@skip_if_no_cuda
class IntegrationFastTrain(unittest.TestCase):
    def setUp(self):
        set_determinism(seed=0)
        monai.config.print_config()

        # write 10 synthetic 64^3 image/segmentation pairs as NIfTI files
        self.data_dir = tempfile.mkdtemp()
        for i in range(10):
            im, seg = create_test_image_3d(64, 64, 64, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(self.data_dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(self.data_dir, f"seg{i:d}.nii.gz"))

    def tearDown(self):
        set_determinism(seed=None)
        shutil.rmtree(self.data_dir)

    def test_train_timing(self, use_metatensor=False):
        images = sorted(glob(os.path.join(self.data_dir, "img*.nii.gz")))
        segs = sorted(glob(os.path.join(self.data_dir, "seg*.nii.gz")))
        train_files = [{"image": img, "label": seg} for img, seg in zip(images, segs)]
        device = torch.device("cuda:0")
        # load, add the channel dim, and move to the GPU once; CacheDataset keeps the result
        train_transforms = Compose(
            [
                LoadImaged(keys=["image", "label"]),
                EnsureChannelFirstd(keys=["image", "label"]),
                ToDeviced(keys=["image", "label"], device=device),
            ]
        )
        train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=0)
        train_loader = ThreadDataLoader(train_ds, num_workers=0, batch_size=10, shuffle=True)
        model = UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=2,
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
        ).to(device)
        for _ in range(5):
            for batch_data in train_loader:
                image = batch_data["image"]
                if not use_metatensor:
                    # drop the metadata and keep a plain torch.Tensor
                    image = image.as_tensor()
                step_start = time.time()
                outputs = model(image)
                step_end = time.time()
                print(outputs.shape, "iter forward", step_end - step_start)


if __name__ == "__main__":
    unittest.main()
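The script times `model(image)` with `time.time()` and no device synchronization. CUDA kernels are launched asynchronously, so these numbers mostly reflect host-side time per forward call, which is likely where MetaTensor's extra per-operation Python dispatch cost shows up. A small helper that can optionally synchronize makes it easy to measure both host-side and end-to-end time. `time_calls` is a hypothetical name, and passing `torch.cuda.synchronize` as the hook (shown in a comment) is an assumption about how one would plug it in.

```python
import time
from typing import Callable, List, Optional


def time_calls(fn: Callable[[], object], repeats: int = 5,
               sync: Optional[Callable[[], None]] = None) -> List[float]:
    """Time `fn()` `repeats` times with a monotonic high-resolution clock.

    If `sync` is given, it is called before starting and before stopping the
    clock, so queued asynchronous work is included in each measurement;
    without it, only the time spent in the host-side call is measured.
    """
    durations = []
    for _ in range(repeats):
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        durations.append(time.perf_counter() - start)
    return durations


# stdlib-only usage; with PyTorch one might call e.g.
# time_calls(lambda: model(image), sync=torch.cuda.synchronize)
calls = []
timings = time_calls(lambda: sum(range(10_000)), repeats=3, sync=lambda: calls.append("sync"))
print(len(timings), len(calls))  # 3 6
```

Comparing runs with and without `sync` separates launch/dispatch overhead from actual GPU compute time.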

Log with use_metatensor=False:

Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 76.09it/s]
torch.Size([10, 2, 64, 64, 64]) iter forward 0.9106705188751221
torch.Size([10, 2, 64, 64, 64]) iter forward 0.009454011917114258
torch.Size([10, 2, 64, 64, 64]) iter forward 0.005485057830810547
torch.Size([10, 2, 64, 64, 64]) iter forward 0.004019260406494141
torch.Size([10, 2, 64, 64, 64]) iter forward 0.003930091857910156
.
----------------------------------------------------------------------
Ran 1 test in 1.338s

Log with use_metatensor=True:

Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 88.51it/s]
(10, 2, 64, 64, 64) iter forward 1.0120453834533691
(10, 2, 64, 64, 64) iter forward 0.1152486801147461
(10, 2, 64, 64, 64) iter forward 0.13366961479187012
(10, 2, 64, 64, 64) iter forward 0.1290879249572754
(10, 2, 64, 64, 64) iter forward 0.12228989601135254
.
----------------------------------------------------------------------
Ran 1 test in 1.911s

OK
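To put a number on the gap between the two logs, the sketch below parses the `iter forward` lines and compares the mean forward time after skipping the first iteration, which is much slower in both runs (presumably one-off warm-up cost; that attribution is an assumption). The timings are copied verbatim from the logs above; `steady_state_mean` is a hypothetical helper.

```python
import re
from statistics import mean

# Forward timings copied from the two logs above.
LOG_PLAIN = """\
torch.Size([10, 2, 64, 64, 64]) iter forward 0.9106705188751221
torch.Size([10, 2, 64, 64, 64]) iter forward 0.009454011917114258
torch.Size([10, 2, 64, 64, 64]) iter forward 0.005485057830810547
torch.Size([10, 2, 64, 64, 64]) iter forward 0.004019260406494141
torch.Size([10, 2, 64, 64, 64]) iter forward 0.003930091857910156
"""
LOG_META = """\
(10, 2, 64, 64, 64) iter forward 1.0120453834533691
(10, 2, 64, 64, 64) iter forward 0.1152486801147461
(10, 2, 64, 64, 64) iter forward 0.13366961479187012
(10, 2, 64, 64, 64) iter forward 0.1290879249572754
(10, 2, 64, 64, 64) iter forward 0.12228989601135254
"""

PATTERN = re.compile(r"iter forward ([0-9.]+)")


def steady_state_mean(log: str, warmup: int = 1) -> float:
    """Mean forward time in seconds, skipping the first `warmup` iterations."""
    times = [float(t) for t in PATTERN.findall(log)]
    return mean(times[warmup:])


plain = steady_state_mean(LOG_PLAIN)
meta = steady_state_mean(LOG_META)
print(f"plain {plain:.4f}s, metatensor {meta:.4f}s, ratio {meta / plain:.1f}x")
```

With these logs the steady-state MetaTensor forward pass comes out roughly 20x slower than the plain-tensor one.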

wyli mentioned this issue Jun 15, 2022
wyli (Contributor, issue author) commented Jun 28, 2022

The inverse operations of the "crop samples" transforms are currently dropped and need further discussion (#4548 (comment)).
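One reason the inverse is tricky: a "crop samples" transform turns one input into several outputs, so each output has to carry its own record of where it was cut from before it can be padded back. The 1-D, stdlib-only sketch below shows that per-sample bookkeeping. `crop_samples` and `invert_crop` are hypothetical helpers, not MONAI's API, and padding with a constant fill is only one possible inverse, since the cropped-away content cannot be recovered.

```python
import random
from typing import Dict, List, Tuple

Sample = Tuple[List[float], Dict[str, int]]


def crop_samples(data: List[float], size: int, num_samples: int, seed: int = 0) -> List[Sample]:
    """Randomly crop `num_samples` windows of length `size`, recording each crop box."""
    rng = random.Random(seed)
    samples = []
    for _ in range(num_samples):
        start = rng.randint(0, len(data) - size)  # inclusive bounds keep the window inside
        record = {"start": start, "orig_len": len(data)}  # per-sample inverse info
        samples.append((data[start:start + size], record))
    return samples


def invert_crop(sample: Sample, fill: float = 0.0) -> List[float]:
    """Pad a cropped sample back to the original length using its own record."""
    cropped, record = sample
    before = [fill] * record["start"]
    after = [fill] * (record["orig_len"] - record["start"] - len(cropped))
    return before + cropped + after


data = [float(i) for i in range(10)]
for sample in crop_samples(data, size=4, num_samples=3):
    restored = invert_crop(sample)
    start = sample[1]["start"]
    assert len(restored) == len(data)
    assert restored[start:start + 4] == data[start:start + 4]
```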

wyli (Contributor, issue author) commented Jul 5, 2022

Idea from the discussions: introduce "strict" and "dev" modes to control manual modifications of the metadata.
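The comment only names the two modes, so the semantics below are assumptions for illustration: in "strict" mode a manual edit of the metadata raises, in "dev" mode it is allowed but warns, and transforms keep a separate update path in both modes. `MetaDict` and `update_from_transform` are hypothetical names, not MONAI's API.

```python
import warnings
from typing import Any, Dict


class MetaDict:
    """Toy metadata container with a hypothetical "strict"/"dev" mode switch.

    strict: manual edits raise PermissionError.
    dev:    manual edits are applied but emit a UserWarning.
    """

    MODES = ("strict", "dev")

    def __init__(self, meta: Dict[str, Any], mode: str = "strict"):
        if mode not in self.MODES:
            raise ValueError(f"mode must be one of {self.MODES}, got {mode!r}")
        self._meta = dict(meta)
        self.mode = mode

    def __getitem__(self, key: str) -> Any:
        return self._meta[key]

    def __setitem__(self, key: str, value: Any) -> None:
        # manual modification path, e.g. meta["spacing"] = ...
        if self.mode == "strict":
            raise PermissionError(f"metadata key {key!r} is read-only in strict mode")
        warnings.warn(f"manually modifying metadata key {key!r} (dev mode)")
        self._meta[key] = value

    def update_from_transform(self, key: str, value: Any) -> None:
        """Path reserved for transforms, which may update metadata in either mode."""
        self._meta[key] = value


meta = MetaDict({"spacing": (1.0, 1.0, 1.0)})  # strict by default
try:
    meta["spacing"] = (2.0, 2.0, 2.0)
except PermissionError as err:
    print(err)  # metadata key 'spacing' is read-only in strict mode
```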

Labels: None yet
Projects: None yet
Participants: 1