-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into refactor/native-pipes
- Loading branch information
Showing
9 changed files
with
479 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,9 @@ wheels/ | |
models | ||
data | ||
logs | ||
wandb | ||
lightning_logs | ||
artifacts | ||
|
||
# Mkdocs | ||
.cache | ||
|
86 changes: 86 additions & 0 deletions
86
darts-segmentation/src/darts_segmentation/training/data.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# ruff: noqa: D101 | ||
# ruff: noqa: D102 | ||
# ruff: noqa: D105 | ||
# ruff: noqa: D107 | ||
"""Training script for DARTS segmentation.""" | ||
|
||
import logging | ||
from pathlib import Path | ||
from typing import Literal | ||
|
||
import albumentations as A # noqa: N812 | ||
import lightning as L # noqa: N812 | ||
import torch | ||
from torch.utils.data import DataLoader, Dataset, random_split | ||
|
||
logger = logging.getLogger(__name__.replace("darts_", "darts.")) | ||
|
||
|
||
class DartsDataset(Dataset): | ||
def __init__(self, data_dir: Path, augment: bool): | ||
self.x_files = sorted((data_dir / "x").glob("*.pt")) | ||
self.y_files = sorted((data_dir / "y").glob("*.pt")) | ||
|
||
assert len(self.x_files) == len( | ||
self.y_files | ||
), f"Dataset corrupted! Got {len(self.x_files)=} and {len(self.y_files)=}!" | ||
|
||
self.transform = ( | ||
A.Compose( | ||
[ | ||
A.HorizontalFlip(), | ||
A.VerticalFlip(), | ||
A.RandomRotate90(), | ||
# A.Blur(), | ||
A.RandomBrightnessContrast(), | ||
A.MultiplicativeNoise(per_channel=True, elementwise=True), | ||
# ToTensorV2(), | ||
] | ||
) | ||
if augment | ||
else None | ||
) | ||
|
||
def __len__(self): | ||
return len(self.x_files) | ||
|
||
def __getitem__(self, idx): | ||
xfile = self.x_files[idx] | ||
yfile = self.y_files[idx] | ||
assert xfile.stem == yfile.stem, f"Dataset corrupted! Files must have the same name, but got {xfile=} {yfile=}!" | ||
|
||
x = torch.load(xfile).numpy() | ||
y = torch.load(yfile).int().numpy() | ||
|
||
# Apply augmentations | ||
if self.transform is not None: | ||
augmented = self.transform(image=x.transpose(1, 2, 0), mask=y) | ||
x = augmented["image"].transpose(2, 0, 1) | ||
y = augmented["mask"] | ||
|
||
return x, y | ||
|
||
|
||
class DartsDataModule(L.LightningDataModule): | ||
def __init__(self, data_dir: Path, batch_size: int, augment: bool = True, num_workers: int = 0): | ||
super().__init__() | ||
self.save_hyperparameters() | ||
self.data_dir = data_dir | ||
self.batch_size = batch_size | ||
self.augment = augment | ||
self.num_workers = num_workers | ||
|
||
def setup(self, stage: Literal["fit", "validate", "test", "predict"] | None = None): | ||
dataset = DartsDataset(self.data_dir, self.augment) | ||
splits = [0.8, 0.1, 0.1] | ||
generator = torch.Generator().manual_seed(42) | ||
self.train, self.val, self.test = random_split(dataset, splits, generator) | ||
|
||
def train_dataloader(self): | ||
return DataLoader(self.train, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True) | ||
|
||
def val_dataloader(self): | ||
return DataLoader(self.val, batch_size=self.batch_size, num_workers=self.num_workers) | ||
|
||
def test_dataloader(self): | ||
return DataLoader(self.test, batch_size=self.batch_size, num_workers=self.num_workers) |
146 changes: 146 additions & 0 deletions
146
darts-segmentation/src/darts_segmentation/training/module.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
# ruff: noqa: D100 | ||
# ruff: noqa: D101 | ||
# ruff: noqa: D102 | ||
# ruff: noqa: D105 | ||
# ruff: noqa: D107 | ||
|
||
"""Training script for DARTS segmentation.""" | ||
|
||
from pathlib import Path | ||
|
||
import lightning as L # noqa: N812 | ||
import segmentation_models_pytorch as smp | ||
import torch.optim as optim | ||
import wandb | ||
from lightning.pytorch.loggers import CSVLogger, WandbLogger | ||
from torchmetrics import ( | ||
AUROC, | ||
ROC, | ||
Accuracy, | ||
AveragePrecision, | ||
CohenKappa, | ||
ConfusionMatrix, | ||
F1Score, | ||
HammingDistance, | ||
JaccardIndex, | ||
MetricCollection, | ||
Precision, | ||
PrecisionRecallCurve, | ||
Recall, | ||
Specificity, | ||
) | ||
from wandb.sdk.wandb_run import Run | ||
|
||
from darts_segmentation.segment import SMPSegmenterConfig | ||
from darts_segmentation.training.viz import plot_sample | ||
|
||
|
||
class SMPSegmenter(L.LightningModule): | ||
def __init__(self, config: SMPSegmenterConfig, learning_rate: float = 1e-5, gamma: float = 0.9): | ||
super().__init__() | ||
self.save_hyperparameters() | ||
self.model = smp.create_model(**config["model"], activation="sigmoid") | ||
|
||
self.loss_fn = smp.losses.FocalLoss(mode="binary") | ||
|
||
metrics = MetricCollection( | ||
{ | ||
"Accuracy": Accuracy(task="binary", validate_args=False), | ||
"Precision": Precision(task="binary", validate_args=False), | ||
"Specificity": Specificity(task="binary", validate_args=False), | ||
"Recall": Recall(task="binary", validate_args=False), | ||
"F1Score": F1Score(task="binary", validate_args=False), | ||
"JaccardIndex": JaccardIndex(task="binary", validate_args=False), | ||
"CohenKappa": CohenKappa(task="binary", validate_args=False), | ||
"HammingDistance": HammingDistance(task="binary", validate_args=False), | ||
} | ||
) | ||
self.train_metrics = metrics.clone(prefix="train/") | ||
self.val_metrics = metrics.clone(prefix="val/") | ||
self.val_metrics.add_metrics( | ||
{ | ||
"AUROC": AUROC(task="binary", thresholds=20, validate_args=False), | ||
"AveragePrecision": AveragePrecision(task="binary", thresholds=20, validate_args=False), | ||
} | ||
) | ||
self.val_roc = ROC(task="binary", thresholds=20, validate_args=False) | ||
self.val_prc = PrecisionRecallCurve(task="binary", thresholds=20, validate_args=False) | ||
self.val_cmx = ConfusionMatrix(task="binary", normalize="true", validate_args=False) | ||
|
||
def training_step(self, batch, batch_idx): | ||
x, y = batch | ||
y_hat = self.model(x).squeeze(1) | ||
loss = self.loss_fn(y_hat, y.long()) | ||
self.train_metrics(y_hat, y) | ||
self.log("train/loss", loss) | ||
self.log_dict(self.train_metrics, on_step=True, on_epoch=False) | ||
return loss | ||
|
||
def on_train_epoch_end(self): | ||
self.train_metrics.reset() | ||
|
||
def validation_step(self, batch, batch_idx): | ||
x, y = batch | ||
y_hat = self.model(x).squeeze(1) | ||
loss = self.loss_fn(y_hat, y.long()) | ||
self.log("val/loss", loss) | ||
|
||
self.val_metrics.update(y_hat, y) | ||
self.val_roc.update(y_hat, y) | ||
self.val_prc.update(y_hat, y) | ||
self.val_cmx.update(y_hat, y) | ||
|
||
# Create figures for the samples | ||
for i in range(x.shape[0]): | ||
fig, _ = plot_sample(x[i], y[i], y_hat[i], self.hparams.config["input_combination"]) | ||
for logger in self.loggers: | ||
if isinstance(logger, CSVLogger): | ||
fig_dir = Path(logger.log_dir) / "figures" | ||
fig_dir.mkdir(exist_ok=True) | ||
fig.savefig(fig_dir / f"sample_{self.global_step}_{batch_idx}_{i}.png") | ||
if isinstance(logger, WandbLogger): | ||
wandb_run: Run = logger.experiment | ||
wandb_run.log({f"val/sample_{batch_idx}_{i}": wandb.Image(fig)}, step=self.global_step) | ||
fig.clear() | ||
|
||
return loss | ||
|
||
def on_validation_epoch_end(self): | ||
self.log_dict(self.val_metrics.compute()) | ||
|
||
self.val_cmx.compute() | ||
self.val_roc.compute() | ||
self.val_prc.compute() | ||
|
||
# Plot roc, prc and confusion matrix to disk and wandb | ||
fig_cmx, _ = self.val_cmx.plot(cmap="Blues") | ||
fig_roc, _ = self.val_roc.plot(score=True) | ||
fig_prc, _ = self.val_prc.plot(score=True) | ||
|
||
# Check for a wandb or csv logger to log the images | ||
for logger in self.loggers: | ||
if isinstance(logger, CSVLogger): | ||
fig_dir = Path(logger.log_dir) / "figures" | ||
fig_dir.mkdir(exist_ok=True) | ||
fig_cmx.savefig(fig_dir / f"cmx_{self.global_step}png") | ||
fig_roc.savefig(fig_dir / f"roc_{self.global_step}png") | ||
fig_prc.savefig(fig_dir / f"prc_{self.global_step}.png") | ||
if isinstance(logger, WandbLogger): | ||
wandb_run: Run = logger.experiment | ||
wandb_run.log({"val/cmx": wandb.Image(fig_cmx)}, step=self.global_step) | ||
wandb_run.log({"val/roc": wandb.Image(fig_roc)}, step=self.global_step) | ||
wandb_run.log({"val/prc": wandb.Image(fig_prc)}, step=self.global_step) | ||
|
||
fig_cmx.clear() | ||
fig_roc.clear() | ||
fig_prc.clear() | ||
|
||
self.val_metrics.reset() | ||
self.val_roc.reset() | ||
self.val_prc.reset() | ||
self.val_cmx.reset() | ||
|
||
def configure_optimizers(self): | ||
optimizer = optim.AdamW(self.parameters(), lr=self.hparams.learning_rate) | ||
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.hparams.gamma) | ||
return [optimizer], [scheduler] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
"""Visualization utilities for the training module.""" | ||
|
||
import matplotlib.colors as mcolors | ||
import matplotlib.patches as mpatches | ||
import matplotlib.pyplot as plt | ||
import torch | ||
|
||
|
||
def plot_sample(x, y, y_pred, input_combinations: list[str]): | ||
"""Plot a single sample with the input, the ground truth and the prediction. | ||
Args: | ||
x (torch.Tensor): The input tensor [C, H, W] (float). | ||
y (torch.Tensor): The ground truth tensor [H, W] (int). | ||
y_pred (torch.Tensor): The prediction tensor [H, W] (float). | ||
input_combinations (list[str]): The combinations of the input bands. | ||
Returns: | ||
tuple[Figure, dict[str, Axes]]: The figure and the axes of the plot. | ||
""" | ||
x = x.cpu() | ||
y = y.cpu() | ||
y_pred = y_pred.detach().cpu() | ||
|
||
classification_labels = (y_pred > 0.5).int() + y * 2 | ||
classification_labels = classification_labels.where(classification_labels != 0, torch.nan) | ||
|
||
# Calculate accuracy and iou | ||
true_positive = (classification_labels == 3).sum() | ||
false_positive = (classification_labels == 1).sum() | ||
false_negative = (classification_labels == 2).sum() | ||
acc = true_positive / (true_positive + false_positive + false_negative) | ||
|
||
cmap = mcolors.ListedColormap(["#cd43b2", "#3e0f2f", "#6cd875"]) | ||
fig, axs = plt.subplot_mosaic([["a", "a", "b", "c"], ["a", "a", "d", "e"]], layout="constrained", figsize=(16, 8)) | ||
|
||
# RGB Plot | ||
red_band = input_combinations.index("red") | ||
green_band = input_combinations.index("green") | ||
blue_band = input_combinations.index("blue") | ||
rgb = x[[red_band, green_band, blue_band]].transpose(0, 2).transpose(0, 1) | ||
ax_rgb = axs["a"] | ||
ax_rgb.imshow(rgb ** (1 / 1.4)) | ||
ax_rgb.imshow(classification_labels, alpha=0.6, cmap=cmap, vmin=1, vmax=3) | ||
# Add a legend | ||
patches = [ | ||
mpatches.Patch(color="#6cd875", label="True Positive"), | ||
mpatches.Patch(color="#3e0f2f", label="False Negative"), | ||
mpatches.Patch(color="#cd43b2", label="False Positive"), | ||
] | ||
ax_rgb.legend(handles=patches, loc="upper left") | ||
# disable axis | ||
ax_rgb.axis("off") | ||
ax_rgb.set_title(f"Accuracy: {acc:.1%}") | ||
|
||
# NIR Plot | ||
nir_band = input_combinations.index("nir") | ||
nir = x[nir_band] | ||
ax_nir = axs["b"] | ||
ax_nir.imshow(nir, vmin=0, vmax=1) | ||
ax_nir.axis("off") | ||
ax_nir.set_title("NIR") | ||
|
||
# TCVIS Plot | ||
tcb_band = input_combinations.index("tc_brightness") | ||
tcg_band = input_combinations.index("tc_greenness") | ||
tcw_band = input_combinations.index("tc_wetness") | ||
tcvis = x[[tcb_band, tcg_band, tcw_band]].transpose(0, 2).transpose(0, 1) | ||
ax_tcv = axs["c"] | ||
ax_tcv.imshow(tcvis) | ||
ax_tcv.axis("off") | ||
ax_tcv.set_title("TCVIS") | ||
|
||
# NDVI Plot | ||
ndvi_band = input_combinations.index("ndvi") | ||
ndvi = x[ndvi_band] | ||
ax_ndvi = axs["d"] | ||
ax_ndvi.imshow(ndvi, vmin=0, vmax=1) | ||
ax_ndvi.axis("off") | ||
ax_ndvi.set_title("NDVI") | ||
|
||
# Prediction Plot | ||
ax_mask = axs["e"] | ||
ax_mask.imshow(y_pred, vmin=0, vmax=1) | ||
ax_mask.axis("off") | ||
ax_mask.set_title("Prediction") | ||
|
||
return fig, axs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.