
Commit 6bb9ca3

separated training and test scripts
1 parent 708a67a commit 6bb9ca3
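
This commit moves the SemanticSegUNet2D Lightning module and its confusion-matrix helpers out of the training script and into viscy.scripts.infection_phenotyping.classify_infection, so the training and test scripts can share one model definition. A minimal sketch of how the slimmed-down training script consumes the shared class; the dataset path, channel names, and the HCSDataModule keyword arguments shown here are illustrative assumptions, not values recorded in this commit:

    # Sketch only: the path, channel names, and HCSDataModule arguments are
    # assumed for illustration; consult the actual script for real values.
    import lightning.pytorch as pl

    from viscy.data.hcs import HCSDataModule
    from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D

    data_module = HCSDataModule(
        data_path="/path/to/dataset.zarr",  # placeholder path
        source_channel="Phase",             # placeholder channel name
        target_channel="Inf_mask",          # placeholder channel name
        split_ratio=0.8,
        batch_size=16,
        num_workers=8,
        architecture="2D",
    )
    model = SemanticSegUNet2D(in_channels=1, out_channels=3)
    trainer = pl.Trainer(max_epochs=100)
    trainer.fit(model, data_module)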

2 files changed: +29 -347 lines
@@ -1,33 +1,15 @@
 # %%
 import torch
-from viscy.data.hcs import HCSDataModule
-
-import numpy as np
-import torch.nn as nn
 import lightning.pytorch as pl
-import torch.nn.functional as F
-import torchmetrics
-
-# import torchview
-from typing import Literal, Sequence
-from skimage.exposure import rescale_intensity
-from sklearn.metrics import ConfusionMatrixDisplay
-from matplotlib.cm import get_cmap
+import torch.nn as nn
 
-# import napari
 from pytorch_lightning.loggers import TensorBoardLogger
-from torch import Tensor
 from pytorch_lightning.callbacks import ModelCheckpoint
 
-# from monai.losses import DiceLoss
-from monai.transforms import DivisiblePad
-from skimage.measure import regionprops
-
-# from viscy.light.engine import VSUNet
-from viscy.unet.networks.Unet2D import Unet2d
-from viscy.data.hcs import Sample
-from viscy.transforms import RandWeightedCropd, RandGaussianNoised
+from viscy.transforms import RandWeightedCropd
 from viscy.transforms import NormalizeSampled
+from viscy.data.hcs import HCSDataModule
+from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D
 
 # %% Create a dataloader and visualize the batches.
 
@@ -91,202 +73,6 @@
 # # Start the napari event loop
 # napari.run()
 
-# %%
-
-# Define a 2D UNet model for semantic segmentation as a lightning module.
-
-
-class SemanticSegUNet2D(pl.LightningModule):
-    # Model for semantic segmentation.
-    def __init__(
-        self,
-        in_channels: int,  # Number of input channels
-        out_channels: int,  # Number of output channels
-        lr: float = 1e-3,  # Learning rate
-        loss_function: nn.Module = nn.CrossEntropyLoss(),  # Loss function
-        schedule: Literal[
-            "WarmupCosine", "Constant"
-        ] = "Constant",  # Learning rate schedule
-        log_batches_per_epoch: int = 2,  # Number of batches to log per epoch
-        log_samples_per_batch: int = 2,  # Number of samples to log per batch
-        checkpoint_path: str = None,  # Path to the checkpoint
-    ):
-        super(SemanticSegUNet2D, self).__init__()  # Call the superclass initializer
-        # Initialize the UNet model
-        self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels)
-        self.lr = lr  # Set the learning rate
-        # Set the loss function to CrossEntropyLoss if none is provided
-        self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss()
-        self.schedule = schedule  # Set the learning rate schedule
-        self.log_batches_per_epoch = (
-            log_batches_per_epoch  # Set the number of batches to log per epoch
-        )
-        self.log_samples_per_batch = (
-            log_samples_per_batch  # Set the number of samples to log per batch
-        )
-        self.training_step_outputs = []  # Initialize the list of training step outputs
-        self.validation_step_outputs = (
-            []
-        )  # Initialize the list of validation step outputs
-
-        self.pred_cm = None  # Initialize the confusion matrix
-        self.index_to_label_dict = ["Background", "Infected", "Uninfected"]
-
-        if checkpoint_path is not None:
-            state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))[
-                "state_dict"
-            ]
-            state_dict.pop("loss_function.weight", None)  # Remove the unexpected key
-            self.load_state_dict(state_dict)  # loading only weights
-
-    # Define the forward pass
-    def forward(self, x):
-        return self.unet_model(x)  # Pass the input through the UNet model
-
-    # Define the optimizer
-    def configure_optimizers(self):
-        optimizer = torch.optim.Adam(
-            self.parameters(), lr=self.lr
-        )  # Use the Adam optimizer
-        return optimizer
-
-    # Define the training step
-    def training_step(self, batch: Sample, batch_idx: int):
-        source = batch["source"]  # Extract the source from the batch
-        target = batch["target"]  # Extract the target from the batch
-        pred = self.forward(source)  # Make a prediction using the source
-        # Convert the target to one-hot encoding
-        target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute(
-            0, 4, 1, 2, 3
-        )
-        target_one_hot = target_one_hot.float()  # Convert the target to float type
-        train_loss = self.loss_function(pred, target_one_hot)  # Calculate the loss
-        # Log the training step outputs if the batch index is less than the number of batches to log per epoch
-        if batch_idx < self.log_batches_per_epoch:
-            self.training_step_outputs.extend(
-                self._detach_sample((source, target_one_hot, pred))
-            )
-        # Log the training loss
-        self.log(
-            "loss/train",
-            train_loss,
-            on_step=True,
-            on_epoch=True,
-            prog_bar=True,
-            logger=True,
-            sync_dist=True,
-        )
-        return train_loss  # Return the training loss
-
-    def validation_step(self, batch: Sample, batch_idx: int):
-        source = batch["source"]  # Extract the source from the batch
-        target = batch["target"]  # Extract the target from the batch
-        pred = self.forward(source)  # Make a prediction using the source
-        # Convert the target to one-hot encoding
-        target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute(
-            0, 4, 1, 2, 3
-        )
-        target_one_hot = target_one_hot.float()  # Convert the target to float type
-        loss = self.loss_function(pred, target_one_hot)  # Calculate the loss
-        # Log the validation step outputs if the batch index is less than the number of batches to log per epoch
-        if batch_idx < self.log_batches_per_epoch:
-            self.validation_step_outputs.extend(
-                self._detach_sample((source, target_one_hot, pred))
-            )
-        # Log the validation loss
-        self.log(
-            "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True
-        )
-        return loss  # Return the validation loss
-
-    def on_predict_start(self):
-        """Pad the input shape to be divisible by the downsampling factor.
-        The inverse of this transform crops the prediction to original shape.
-        """
-        down_factor = 2**self.unet_model.num_blocks
-        self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor))
-
-    # Define the prediction step
-    def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0):
-        source = self._predict_pad(batch["source"])  # Pad the source
-        target = batch["target"]  # Extract the target from the batch
-        logits = self._predict_pad.inverse(
-            self.forward(source)
-        )  # Predict and remove padding.
-        prob_pred = F.softmax(logits, dim=1)  # Calculate the probabilities
-        # Go from probabilities/one-hot encoded data to class labels.
-        labels_pred = torch.argmax(prob_pred, dim=1)  # Calculate the predicted labels
-        labels_target = torch.argmax(target, dim=1)  # Calculate the target labels
-        # FIXME: Check if compliant with lightning API
-        self.pred_cm = confusion_matrix_per_cell(
-            labels_target, labels_pred, num_classes=3
-        )
-
-        return prob_pred  # log the probabilities instead of logits.
-
-    # Accumulate the confusion matrix at the end of prediction epoch and log.
-    def on_predict_epoch_end(self):
-        confusion_matrix = self.pred_cm.compute().cpu().numpy()
-        self.logger.experiment.add_figure(
-            "Confusion Matrix per Cell",
-            plot_confusion_matrix(confusion_matrix, self.index_to_label_dict),
-            self.current_epoch,
-        )
-
-    # Define what happens at the end of a training epoch
-    def on_train_epoch_end(self):
-        self._log_samples(
-            "train_samples", self.training_step_outputs
-        )  # Log the training samples
-        self.training_step_outputs = []  # Reset the list of training step outputs
-
-    # Define what happens at the end of a validation epoch
-    def on_validation_epoch_end(self):
-        self._log_samples(
-            "val_samples", self.validation_step_outputs
-        )  # Log the validation samples
-        self.validation_step_outputs = []  # Reset the list of validation step outputs
-        # TODO: Log the confusion matrix
-
-    # Define a method to detach a sample
-    def _detach_sample(self, imgs: Sequence[Tensor]):
-        # Detach the images and convert them to numpy arrays
-        num_samples = 3
-        return [
-            [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs]
-            for i in range(num_samples)
-        ]
-
-    # Define a method to log samples
-    def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]):
-        images_grid = []  # Initialize the list of image grids
-        for sample_images in imgs:  # For each sample image
-            images_row = []  # Initialize the list of image rows
-            for i, image in enumerate(
-                sample_images
-            ):  # For each image in the sample images
-                cm_name = "gray" if i == 0 else "inferno"  # Set the colormap name
-                if image.ndim == 2:  # If the image is 2D
-                    image = image[np.newaxis]  # Add a new axis
-                for channel in image:  # For each channel in the image
-                    channel = rescale_intensity(
-                        channel, out_range=(0, 1)
-                    )  # Rescale the intensity of the channel
-                    render = get_cmap(cm_name)(channel, bytes=True)[
-                        ..., :3
-                    ]  # Render the channel
-                    images_row.append(
-                        render
-                    )  # Append the render to the list of image rows
-            images_grid.append(
-                np.concatenate(images_row, axis=1)
-            )  # Append the concatenated image rows to the list of image grids
-        grid = np.concatenate(images_grid, axis=0)  # Concatenate the image grids
-        # Log the image grid
-        self.logger.experiment.add_image(
-            key, grid, self.current_epoch, dataformats="HWC"
-        )
-
 
 # %% Define the logger
 logger = TensorBoardLogger(
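
The class removed above now lives in classify_infection.py. Its checkpoint_path argument is what lets the separated test script restore trained weights without redefining the model; a hedged sketch of that usage, where the checkpoint path and channel counts are placeholders, not values recorded in this commit:

    # Hypothetical test-script usage of the shared class.
    from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D

    model = SemanticSegUNet2D(
        in_channels=1,   # placeholder channel count
        out_channels=3,  # Background / Infected / Uninfected
        checkpoint_path="/path/to/checkpoints/last.ckpt",  # placeholder path
    )
    model.eval()  # weights are loaded in __init__ when checkpoint_path is given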
@@ -328,105 +114,3 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]):
 # Run training.
 
 trainer.fit(model, data_module)
-
-# %% Methods to compute confusion matrix per cell using torchmetrics
-
-
-# The confusion matrix at the single-cell resolution.
-def confusion_matrix_per_cell(
-    y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int
-):
-    """Compute confusion matrix per cell.
-
-    Args:
-        y_true (torch.Tensor): Ground truth label image (BXHXW).
-        y_pred (torch.Tensor): Predicted label image (BXHXW).
-        num_classes (int): Number of classes.
-
-    Returns:
-        torch.Tensor: Confusion matrix per cell (BXCXC).
-    """
-    # Convert the image class to the nuclei class
-    nuclei_true, nuclei_pred = image_class_to_nuclei_class(y_true, y_pred, num_classes)
-    # Compute the confusion matrix per cell
-    confusion_matrix_per_cell = torchmetrics.functional.confusion_matrix(
-        nuclei_true(nuclei_true > 0),  # indexing just non-background pixels.
-        nuclei_pred(nuclei_true > 0),
-        num_classes=num_classes,
-        task="multi_class",
-    )
-    return confusion_matrix_per_cell
-
-
-# These images can be logged with prediction.
-def image_class_to_nuclei_class(
-    y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int
-):
-    """Convert the class of the image to the class of the nuclei.
-
-    Args:
-        label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation.
-        num_classes (int): Number of classes.
-
-    Returns:
-        torch.Tensor: Label images with a consensus class at the centroid of nuclei.
-    """
-    nuclei_true = torch.zeros_like(y_true)
-    nuclei_pred = torch.zeros_like(y_pred)
-    batch_size = y_true.size(0)
-    # find centroids of nuclei from y_true
-    for i in range(batch_size):
-        regions = regionprops(y_true[i].cpu().numpy())
-        # Find centroids, pixel coordinates from the ground truth.
-        for region in regions:
-            centroid = region.centroid
-            pixel_ids = region.coords
-            # Find the class of the nuclei in the ground truth and prediction.
-            pix_labels_true = y_true[i, pixel_ids[:, 0], pixel_ids[:, 1]]
-            consensus_class_true = np.mode(pix_labels_true[:])
-
-            pix_labels_pred = y_pred[i, pixel_ids[:, 0], pixel_ids[:, 1]]
-            consensus_class_pred = np.mode(pix_labels_pred[:])
-            nuclei_true[i, centroid[0], centroid[1]] = consensus_class_true
-            nuclei_pred[i, centroid[0], centroid[1]] = consensus_class_pred
-
-    # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction.
-    # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction.
-
-    return nuclei_true, nuclei_pred
-
-
-def plot_confusion_matrix(confusion_matrix, index_to_label_dict):
-    # Create a figure and axis to plot the confusion matrix
-    fig, ax = plt.subplots()
-
-    # Create a color heatmap for the confusion matrix
-    cax = ax.matshow(confusion_matrix, cmap="viridis")
-
-    # Create a colorbar and set the label
-    fig.colorbar(cax, label="Frequency")
-
-    # Set labels for the classes
-
-    ax.set_xticks(np.arange(len(index_to_label_dict)))
-    ax.set_yticks(np.arange(len(index_to_label_dict)))
-    ax.set_xticklabels(index_to_label_dict.values(), rotation=45)
-    ax.set_yticklabels(index_to_label_dict.values())
-
-    # Set labels for the axes
-    ax.set_xlabel("Predicted")
-    ax.set_ylabel("True")
-
-    # Add text annotations to the confusion matrix
-    for i in range(len(index_to_label_dict)):
-        for j in range(len(index_to_label_dict)):
-            ax.text(
-                j,
-                i,
-                str(int(confusion_matrix[i, j])),
-                ha="center",
-                va="center",
-                color="white",
-            )
-
-    return fig
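
The removed helpers contain a few apparent bugs that presumably needed fixing wherever they landed: np.mode does not exist in NumPy, nuclei_true(nuclei_true > 0) calls a tensor instead of indexing it, torchmetrics spells the task literal "multiclass", and region centroids are floats that cannot index a tensor directly. A corrected sketch under those assumptions; the moved classify_infection.py is not shown in this diff, so this is a guess at the intended behavior, not the committed code:

    # Corrected sketch, not the committed code: torch.mode replaces the
    # nonexistent np.mode, boolean indexing replaces the tensor call, and
    # centroids are cast to int before being used as indices.
    import torch
    import torchmetrics.functional as tmf
    from skimage.measure import regionprops


    def image_class_to_nuclei_class(y_true: torch.Tensor, y_pred: torch.Tensor):
        """Collapse per-pixel labels to a consensus label at each region centroid."""
        nuclei_true = torch.zeros_like(y_true)
        nuclei_pred = torch.zeros_like(y_pred)
        for i in range(y_true.size(0)):
            for region in regionprops(y_true[i].cpu().numpy()):
                row, col = (int(c) for c in region.centroid)  # centroids are floats
                coords = region.coords  # (N, 2) pixel coordinates of the region
                pix_true = y_true[i, coords[:, 0], coords[:, 1]]
                pix_pred = y_pred[i, coords[:, 0], coords[:, 1]]
                nuclei_true[i, row, col] = torch.mode(pix_true).values
                nuclei_pred[i, row, col] = torch.mode(pix_pred).values
        return nuclei_true, nuclei_pred


    def confusion_matrix_per_cell(
        y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int
    ):
        nuclei_true, nuclei_pred = image_class_to_nuclei_class(y_true, y_pred)
        mask = nuclei_true > 0  # keep only non-background consensus pixels
        return tmf.confusion_matrix(
            preds=nuclei_pred[mask],
            target=nuclei_true[mask],
            task="multiclass",
            num_classes=num_classes,
        )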
