@@ -1,33 +1,15 @@
 # %%
 import torch
-from viscy.data.hcs import HCSDataModule
-
-import numpy as np
-import torch.nn as nn
 import lightning.pytorch as pl
-import torch.nn.functional as F
-import torchmetrics
-
-# import torchview
-from typing import Literal, Sequence
-from skimage.exposure import rescale_intensity
-from sklearn.metrics import ConfusionMatrixDisplay
-from matplotlib.cm import get_cmap
+import torch.nn as nn
 
-# import napari
 from pytorch_lightning.loggers import TensorBoardLogger
-from torch import Tensor
 from pytorch_lightning.callbacks import ModelCheckpoint
 
-# from monai.losses import DiceLoss
-from monai.transforms import DivisiblePad
-from skimage.measure import regionprops
-
-# from viscy.light.engine import VSUNet
-from viscy.unet.networks.Unet2D import Unet2d
-from viscy.data.hcs import Sample
-from viscy.transforms import RandWeightedCropd, RandGaussianNoised
+from viscy.transforms import RandWeightedCropd
 from viscy.transforms import NormalizeSampled
+from viscy.data.hcs import HCSDataModule
+from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D
 
 # %% Create a dataloader and visualize the batches.
 
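With this refactor, the script no longer defines the model inline: `SemanticSegUNet2D` is now imported from `viscy.scripts.infection_phenotyping.classify_infection`. A minimal sketch of how the refactored imports fit together; the argument names and values below are illustrative assumptions, since the script's actual data-module configuration sits in the elided lines of this diff:

```python
from viscy.data.hcs import HCSDataModule
from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D

# Hypothetical configuration; the real dataset path and channel names live in
# the unchanged part of the script, which is not shown in this diff.
data_module = HCSDataModule(
    "infection.zarr",  # hypothetical HCS OME-Zarr store
    source_channel=["Phase", "Fluor"],  # hypothetical channel names
    target_channel=["inf_mask"],
    batch_size=4,
    num_workers=8,
)
model = SemanticSegUNet2D(in_channels=2, out_channels=3)
```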
@@ -91,202 +73,6 @@
 # # Start the napari event loop
 # napari.run()
 
-# %%
-
-# Define a 2D UNet model for semantic segmentation as a lightning module.
-
-
-class SemanticSegUNet2D(pl.LightningModule):
-    # Model for semantic segmentation.
-    def __init__(
-        self,
-        in_channels: int,  # Number of input channels
-        out_channels: int,  # Number of output channels
-        lr: float = 1e-3,  # Learning rate
-        loss_function: nn.Module = nn.CrossEntropyLoss(),  # Loss function
-        schedule: Literal[
-            "WarmupCosine", "Constant"
-        ] = "Constant",  # Learning rate schedule
-        log_batches_per_epoch: int = 2,  # Number of batches to log per epoch
-        log_samples_per_batch: int = 2,  # Number of samples to log per batch
-        checkpoint_path: str = None,  # Path to the checkpoint
-    ):
-        super(SemanticSegUNet2D, self).__init__()  # Call the superclass initializer
-        # Initialize the UNet model
-        self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels)
-        self.lr = lr  # Set the learning rate
-        # Set the loss function to CrossEntropyLoss if none is provided
-        self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss()
-        self.schedule = schedule  # Set the learning rate schedule
-        self.log_batches_per_epoch = (
-            log_batches_per_epoch  # Set the number of batches to log per epoch
-        )
-        self.log_samples_per_batch = (
-            log_samples_per_batch  # Set the number of samples to log per batch
-        )
-        self.training_step_outputs = []  # Initialize the list of training step outputs
-        self.validation_step_outputs = (
-            []
-        )  # Initialize the list of validation step outputs
-
-        self.pred_cm = None  # Initialize the confusion matrix
-        # Map class indices to label names; plot_confusion_matrix expects a dict.
-        self.index_to_label_dict = {0: "Background", 1: "Infected", 2: "Uninfected"}
-
-        if checkpoint_path is not None:
-            state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))[
-                "state_dict"
-            ]
-            state_dict.pop("loss_function.weight", None)  # Remove the unexpected key
-            self.load_state_dict(state_dict)  # loading only weights
-
-    # Define the forward pass
-    def forward(self, x):
-        return self.unet_model(x)  # Pass the input through the UNet model
-
-    # Define the optimizer
-    def configure_optimizers(self):
-        optimizer = torch.optim.Adam(
-            self.parameters(), lr=self.lr
-        )  # Use the Adam optimizer
-        return optimizer
-
-    # Define the training step
-    def training_step(self, batch: Sample, batch_idx: int):
-        source = batch["source"]  # Extract the source from the batch
-        target = batch["target"]  # Extract the target from the batch
-        pred = self.forward(source)  # Make a prediction using the source
-        # Convert the target to one-hot encoding
-        target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute(
-            0, 4, 1, 2, 3
-        )
-        target_one_hot = target_one_hot.float()  # Convert the target to float type
-        train_loss = self.loss_function(pred, target_one_hot)  # Calculate the loss
-        # Log the training step outputs if the batch index is less than the number of batches to log per epoch
-        if batch_idx < self.log_batches_per_epoch:
-            self.training_step_outputs.extend(
-                self._detach_sample((source, target_one_hot, pred))
-            )
-        # Log the training loss
-        self.log(
-            "loss/train",
-            train_loss,
-            on_step=True,
-            on_epoch=True,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
-        return train_loss  # Return the training loss
-
-    def validation_step(self, batch: Sample, batch_idx: int):
-        source = batch["source"]  # Extract the source from the batch
-        target = batch["target"]  # Extract the target from the batch
-        pred = self.forward(source)  # Make a prediction using the source
-        # Convert the target to one-hot encoding
-        target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute(
-            0, 4, 1, 2, 3
-        )
-        target_one_hot = target_one_hot.float()  # Convert the target to float type
-        loss = self.loss_function(pred, target_one_hot)  # Calculate the loss
-        # Log the validation step outputs if the batch index is less than the number of batches to log per epoch
-        if batch_idx < self.log_batches_per_epoch:
-            self.validation_step_outputs.extend(
-                self._detach_sample((source, target_one_hot, pred))
-            )
-        # Log the validation loss
-        self.log(
-            "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True
-        )
-        return loss  # Return the validation loss
-
-    def on_predict_start(self):
-        """Pad the input shape to be divisible by the downsampling factor.
-        The inverse of this transform crops the prediction to the original shape.
-        """
-        down_factor = 2**self.unet_model.num_blocks
-        self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor))
-
-    # Define the prediction step
-    def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0):
-        source = self._predict_pad(batch["source"])  # Pad the source
-        target = batch["target"]  # Extract the target from the batch
-        logits = self._predict_pad.inverse(
-            self.forward(source)
-        )  # Predict and remove padding.
-        prob_pred = F.softmax(logits, dim=1)  # Calculate the probabilities
-        # Go from probabilities/one-hot encoded data to class labels.
-        labels_pred = torch.argmax(prob_pred, dim=1)  # Calculate the predicted labels
-        labels_target = torch.argmax(target, dim=1)  # Calculate the target labels
-        # FIXME: Check if compliant with lightning API
-        batch_cm = confusion_matrix_per_cell(labels_target, labels_pred, num_classes=3)
-        # Accumulate the per-batch confusion matrices across the prediction epoch.
-        self.pred_cm = batch_cm if self.pred_cm is None else self.pred_cm + batch_cm
-
-        return prob_pred  # log the probabilities instead of logits.
-
-    # Log the accumulated confusion matrix at the end of the prediction epoch.
-    def on_predict_epoch_end(self):
-        confusion_matrix = self.pred_cm.cpu().numpy()
-        self.logger.experiment.add_figure(
-            "Confusion Matrix per Cell",
-            plot_confusion_matrix(confusion_matrix, self.index_to_label_dict),
-            self.current_epoch,
-        )
-
-    # Define what happens at the end of a training epoch
-    def on_train_epoch_end(self):
-        self._log_samples(
-            "train_samples", self.training_step_outputs
-        )  # Log the training samples
-        self.training_step_outputs = []  # Reset the list of training step outputs
-
-    # Define what happens at the end of a validation epoch
-    def on_validation_epoch_end(self):
-        self._log_samples(
-            "val_samples", self.validation_step_outputs
-        )  # Log the validation samples
-        self.validation_step_outputs = []  # Reset the list of validation step outputs
-        # TODO: Log the confusion matrix
-
-    # Define a method to detach a sample
-    def _detach_sample(self, imgs: Sequence[Tensor]):
-        # Detach the images and convert them to numpy arrays
-        num_samples = self.log_samples_per_batch  # Samples per batch to log
-        return [
-            [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs]
-            for i in range(num_samples)
-        ]
-
-    # Define a method to log samples
-    def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]):
-        images_grid = []  # Initialize the list of image grids
-        for sample_images in imgs:  # For each sample image
-            images_row = []  # Initialize the list of image rows
-            for i, image in enumerate(
-                sample_images
-            ):  # For each image in the sample images
-                cm_name = "gray" if i == 0 else "inferno"  # Set the colormap name
-                if image.ndim == 2:  # If the image is 2D
-                    image = image[np.newaxis]  # Add a new axis
-                for channel in image:  # For each channel in the image
-                    channel = rescale_intensity(
-                        channel, out_range=(0, 1)
-                    )  # Rescale the intensity of the channel
-                    render = get_cmap(cm_name)(channel, bytes=True)[
-                        ..., :3
-                    ]  # Render the channel
-                    images_row.append(
-                        render
-                    )  # Append the render to the list of image rows
-            images_grid.append(
-                np.concatenate(images_row, axis=1)
-            )  # Append the concatenated image rows to the list of image grids
-        grid = np.concatenate(images_grid, axis=0)  # Concatenate the image grids
-        # Log the image grid
-        self.logger.experiment.add_image(
-            key, grid, self.current_epoch, dataformats="HWC"
-        )
-
 
 # %% Define the logger
 logger = TensorBoardLogger(
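The removed `training_step`/`validation_step` pair converts integer label maps to one-hot tensors before applying `CrossEntropyLoss`. A standalone sketch of that shape manipulation; the tensor shapes here are arbitrary, chosen only for illustration:

```python
import torch
import torch.nn.functional as F

# Integer class labels with a singleton channel dim: (B, 1, Z, Y, X).
target = torch.randint(0, 3, (2, 1, 5, 16, 16))
# one_hot appends the class dim last, giving (B, Z, Y, X, C);
# permute moves it next to the batch dim to match the prediction layout.
target_one_hot = (
    F.one_hot(target.squeeze(1).long(), num_classes=3)
    .permute(0, 4, 1, 2, 3)
    .float()
)
print(target_one_hot.shape)  # torch.Size([2, 3, 5, 16, 16])
```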
@@ -328,105 +114,3 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]):
 # Run training.
 
 trainer.fit(model, data_module)
-
-# %% Methods to compute confusion matrix per cell using torchmetrics
-
-
-# The confusion matrix at the single-cell resolution.
-def confusion_matrix_per_cell(
-    y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int
-):
-    """Compute confusion matrix per cell.
-
-    Args:
-        y_true (torch.Tensor): Ground truth label image (B, H, W).
-        y_pred (torch.Tensor): Predicted label image (B, H, W).
-        num_classes (int): Number of classes.
-
-    Returns:
-        torch.Tensor: Confusion matrix over cells (C, C).
-    """
-    # Convert the image class to the nuclei class
-    nuclei_true, nuclei_pred = image_class_to_nuclei_class(y_true, y_pred, num_classes)
-    # Compute the confusion matrix per cell
-    confusion_matrix_per_cell = torchmetrics.functional.confusion_matrix(
-        nuclei_true[nuclei_true > 0],  # indexing just non-background pixels.
-        nuclei_pred[nuclei_true > 0],
-        num_classes=num_classes,
-        task="multiclass",
-    )
-    return confusion_matrix_per_cell
-
-
-# These images can be logged with prediction.
-def image_class_to_nuclei_class(
-    y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int
-):
-    """Convert the class of the image to the class of the nuclei.
-
-    Args:
-        y_true (torch.Tensor): Ground truth label image (B, H, W); values are integers that represent the semantic segmentation.
-        y_pred (torch.Tensor): Predicted label image (B, H, W).
-        num_classes (int): Number of classes.
-
-    Returns:
-        torch.Tensor: Label images with a consensus class at the centroid of each nucleus.
-    """
-    nuclei_true = torch.zeros_like(y_true)
-    nuclei_pred = torch.zeros_like(y_pred)
-    batch_size = y_true.size(0)
-    # Find all instances of nuclei in the ground truth and compute the consensus
-    # class of each nucleus in both the ground truth and the prediction.
-    for i in range(batch_size):
-        regions = regionprops(y_true[i].cpu().numpy())
-        # Find centroids and pixel coordinates from the ground truth.
-        for region in regions:
-            centroid = region.centroid
-            pixel_ids = region.coords
-            # Take the majority (mode) class over the nucleus pixels.
-            pix_labels_true = y_true[i, pixel_ids[:, 0], pixel_ids[:, 1]]
-            consensus_class_true = torch.mode(pix_labels_true).values
-
-            pix_labels_pred = y_pred[i, pixel_ids[:, 0], pixel_ids[:, 1]]
-            consensus_class_pred = torch.mode(pix_labels_pred).values
-            nuclei_true[i, int(centroid[0]), int(centroid[1])] = consensus_class_true
-            nuclei_pred[i, int(centroid[0]), int(centroid[1])] = consensus_class_pred
-
-    return nuclei_true, nuclei_pred
-
-
-def plot_confusion_matrix(confusion_matrix, index_to_label_dict):
-    import matplotlib.pyplot as plt  # local import; plt is not imported at module scope
-
-    # Create a figure and axis to plot the confusion matrix
-    fig, ax = plt.subplots()
-
-    # Create a color heatmap for the confusion matrix
-    cax = ax.matshow(confusion_matrix, cmap="viridis")
-
-    # Create a colorbar and set the label
-    fig.colorbar(cax, label="Frequency")
-
-    # Set labels for the classes
-    ax.set_xticks(np.arange(len(index_to_label_dict)))
-    ax.set_yticks(np.arange(len(index_to_label_dict)))
-    ax.set_xticklabels(index_to_label_dict.values(), rotation=45)
-    ax.set_yticklabels(index_to_label_dict.values())
-
-    # Set labels for the axes
-    ax.set_xlabel("Predicted")
-    ax.set_ylabel("True")
-
-    # Add text annotations to the confusion matrix
-    for i in range(len(index_to_label_dict)):
-        for j in range(len(index_to_label_dict)):
-            ax.text(
-                j,
-                i,
-                str(int(confusion_matrix[i, j])),
-                ha="center",
-                va="center",
-                color="white",
-            )
-
-    return fig
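These helpers were moved rather than deleted: the new import at the top of the file pulls `SemanticSegUNet2D` from `classify_infection`. Assuming the confusion-matrix helpers moved along with the class (the diff does not show their new home explicitly), a usage sketch on synthetic labels:

```python
import torch
from viscy.scripts.infection_phenotyping.classify_infection import (  # assumed location
    confusion_matrix_per_cell,
    plot_confusion_matrix,
)

# Synthetic (B, H, W) integer label images standing in for real segmentations.
y_true = torch.randint(0, 3, (1, 64, 64))
y_pred = torch.randint(0, 3, (1, 64, 64))
cm = confusion_matrix_per_cell(y_true, y_pred, num_classes=3)
fig = plot_confusion_matrix(
    cm.cpu().numpy(), {0: "Background", 1: "Infected", 2: "Uninfected"}
)
fig.savefig("confusion_matrix_per_cell.png")
```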
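Separately, the removed `on_predict_start` used MONAI's `DivisiblePad` so that prediction inputs divide evenly by the UNet's downsampling factor, and its `inverse` to crop the output back. A minimal round-trip sketch, assuming `num_blocks = 4` (a factor of 16) and a `MetaTensor` input so the padding operation is traceable and invertible:

```python
import torch
from monai.data import MetaTensor
from monai.transforms import DivisiblePad

# Pad only the last two (Y, X) dims to multiples of 2**num_blocks = 16.
pad = DivisiblePad(k=(0, 0, 16, 16))
x = MetaTensor(torch.zeros(1, 2, 5, 30, 45))  # (B, C, Z, Y, X)
padded = pad(x)
print(padded.shape)               # (1, 2, 5, 32, 48)
print(pad.inverse(padded).shape)  # cropped back to (1, 2, 5, 30, 45)
```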