# segmentation.py
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import torch
from lib.transforms.transforms import NormalizeLabelsInDatasetd
from monai.handlers import TensorBoardImageHandler, from_engine
from monai.inferers import SlidingWindowInferer
from monai.losses import DiceCELoss
from monai.transforms import (
Activationsd,
AsDiscreted,
CropForegroundd,
EnsureChannelFirstd,
EnsureTyped,
GaussianSmoothd,
LoadImaged,
NormalizeIntensityd,
Orientationd,
RandSpatialCropd,
ScaleIntensityd,
SelectItemsd,
Spacingd,
)
from monailabel.tasks.train.basic_train import BasicTrainTask, Context
from monailabel.tasks.train.utils import region_wise_metrics
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class Segmentation(BasicTrainTask):
    """Training task that wires a segmentation network into ``BasicTrainTask``.

    The class supplies the network, optimizer, loss, transform pipelines,
    validation inferer, key metrics and an extra TensorBoard image handler.
    It reads ``self._labels``, which is not assigned in this class —
    presumably the base class sets it from ``**kwargs``; verify against
    ``BasicTrainTask``.
    """

    def __init__(
        self,
        model_dir,
        network,
        roi_size=(96, 96, 96),
        target_spacing=(1.0, 1.0, 1.0),
        num_samples=4,
        description="Train Segmentation model",
        **kwargs,
    ):
        """Store the task configuration and initialise the base task.

        Args:
            model_dir: forwarded to ``BasicTrainTask.__init__``.
            network: network object returned unchanged by :meth:`network`.
            roi_size: patch size; its first three entries are used for
                foreground-crop divisibility and random cropping, and the
                whole value is handed to the sliding-window inferer.
            target_spacing: ``pixdim`` passed to ``Spacingd``.
            num_samples: stored on the instance; not read elsewhere in
                this class.
            description: forwarded to ``BasicTrainTask.__init__``.
            **kwargs: forwarded to ``BasicTrainTask.__init__``.
        """
        # Assigned before the base initialiser runs, as in the original
        # ordering (presumably so base-class setup can already see them).
        self.num_samples = num_samples
        self.target_spacing = target_spacing
        self.roi_size = roi_size
        self._network = network
        super().__init__(model_dir, description, **kwargs)

    def network(self, context: Context):
        """Return the network object given at construction time."""
        return self._network

    def optimizer(self, context: Context):
        """Return an AdamW optimizer over ``context.network`` parameters."""
        trainable = context.network.parameters()
        return torch.optim.AdamW(trainable, lr=1e-4, weight_decay=1e-5)

    def loss_function(self, context: Context):
        """Return a ``DiceCELoss`` with ``to_onehot_y`` and ``softmax`` enabled."""
        return DiceCELoss(softmax=True, to_onehot_y=True)

    def lr_scheduler_handler(self, context: Context):
        """Return ``None``: this task supplies no LR-scheduler handler."""
        return None

    def train_data_loader(self, context, num_workers=0, shuffle=False):
        """Delegate to the base-class loader, always passing ``True`` third.

        NOTE: the ``shuffle`` argument is accepted but ignored; ``True`` is
        passed in its position regardless of the caller's value.
        """
        return super().train_data_loader(context, num_workers, True)

    def train_pre_transforms(self, context: Context):
        """Return the list of dictionary transforms applied to training data."""
        pair = ("image", "label")
        pipeline = [
            LoadImaged(keys=pair),
            # Original note: "Specially for missing labels".
            NormalizeLabelsInDatasetd(keys="label", label_names=self._labels),
            EnsureChannelFirstd(keys=pair),
            EnsureTyped(keys=pair, device=context.device),
            Orientationd(keys=pair, axcodes="RAS"),
            Spacingd(keys=pair, pixdim=self.target_spacing, mode=("bilinear", "nearest")),
            NormalizeIntensityd(keys="image", nonzero=True),
            # Foreground is taken from the image here (validation uses the label).
            CropForegroundd(
                keys=pair,
                source_key="image",
                margin=10,
                k_divisible=[self.roi_size[axis] for axis in range(3)],
            ),
            GaussianSmoothd(keys="image", sigma=0.4),
            ScaleIntensityd(keys="image", minv=-1.0, maxv=1.0),
            RandSpatialCropd(
                keys=["image", "label"],
                roi_size=[self.roi_size[axis] for axis in range(3)],
                random_size=False,
            ),
            SelectItemsd(keys=pair),
        ]
        return pipeline

    def train_post_transforms(self, context: Context):
        """Return the transforms applied to training outputs.

        ``pred`` goes through softmax and argmax; ``pred`` and ``label`` are
        both one-hot encoded with ``len(self._labels) + 1`` classes.
        """
        class_count = len(self._labels) + 1
        return [
            EnsureTyped(keys="pred", device=context.device),
            Activationsd(keys="pred", softmax=True),
            AsDiscreted(
                keys=("pred", "label"),
                argmax=(True, False),
                to_onehot=class_count,
            ),
        ]

    def val_pre_transforms(self, context: Context):
        """Return the list of dictionary transforms applied to validation data."""
        pair = ("image", "label")
        pipeline = [
            LoadImaged(keys=pair),
            # Original note: "Specially for missing labels".
            NormalizeLabelsInDatasetd(keys="label", label_names=self._labels),
            EnsureTyped(keys=pair),
            EnsureChannelFirstd(keys=pair),
            Orientationd(keys=pair, axcodes="RAS"),
            Spacingd(keys=pair, pixdim=self.target_spacing, mode=("bilinear", "nearest")),
            NormalizeIntensityd(keys="image", nonzero=True),
            # Disabled alternative kept for reference:
            # ScaleIntensityRanged(keys="image", a_min=-1000, a_max=1900, b_min=0.0, b_max=1.0, clip=True),
            # Foreground is taken from the label here (training uses the image).
            CropForegroundd(
                keys=pair,
                source_key="label",
                margin=10,
                k_divisible=[self.roi_size[axis] for axis in range(3)],
            ),
            GaussianSmoothd(keys="image", sigma=0.4),
            ScaleIntensityd(keys="image", minv=-1.0, maxv=1.0),
            SelectItemsd(keys=pair),
        ]
        return pipeline

    def val_inferer(self, context: Context):
        """Return the sliding-window inferer used for validation."""
        return SlidingWindowInferer(
            roi_size=self.roi_size,
            sw_batch_size=2,
            overlap=0.4,
            padding_mode="replicate",
            mode="gaussian",
        )

    def norm_labels(self):
        """Map each label name to its renumbered index.

        This should be applied along with the ``NormalizeLabelsInDatasetd``
        transform. ``"background"`` maps to 0; every other name maps to its
        1-based position in ``self._labels``. A ``"background"`` entry still
        occupies a position, so names listed after it skip one index.
        """
        return {
            name: 0 if name == "background" else position
            for position, (name, _) in enumerate(self._labels.items(), start=1)
        }

    def train_key_metric(self, context: Context):
        """Return region-wise metrics for training, keyed ``train_mean_dice``."""
        label_indices = self.norm_labels()
        return region_wise_metrics(label_indices, "train_mean_dice", "train")

    def val_key_metric(self, context: Context):
        """Return region-wise metrics for validation, keyed ``val_mean_dice``."""
        label_indices = self.norm_labels()
        return region_wise_metrics(label_indices, "val_mean_dice", "val")

    def train_handlers(self, context: Context):
        """Return the base-class handlers, plus an image logger on rank 0.

        When ``context.local_rank`` is 0, a ``TensorBoardImageHandler``
        writing to ``context.events_dir`` (interval 20, epoch level) is
        appended to the list returned by the base class.
        """
        handlers = super().train_handlers(context)
        if context.local_rank != 0:
            return handlers

        image_logger = TensorBoardImageHandler(
            log_dir=context.events_dir,
            batch_transform=from_engine(["image", "label"]),
            output_transform=from_engine(["pred"]),
            interval=20,
            epoch_level=True,
        )
        handlers.append(image_logger)
        return handlers