Skip to content

Commit bf1b9d3

Browse files
committed
fix combined loader training for virtual staining task
1 parent 31522ae commit bf1b9d3

File tree

1 file changed

+37
-24
lines changed

1 file changed

+37
-24
lines changed

viscy/light/engine.py

+37-24
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def __init__(
150150
self.log_batches_per_epoch = log_batches_per_epoch
151151
self.log_samples_per_batch = log_samples_per_batch
152152
self.training_step_outputs = []
153+
self.validation_losses = []
153154
self.validation_step_outputs = []
154155
# required to log the graph
155156
if architecture == "2D":
@@ -175,31 +176,46 @@ def forward(self, x: Tensor) -> Tensor:
175176
return self.model(x)
176177

177178
def training_step(self, batch: Sample, batch_idx: int):
    """Compute the mean training loss over the sub-batches of a combined loader.

    ``batch`` is an iterable of per-dataloader samples (CombinedLoader);
    each element is a dict with ``"source"`` and ``"target"`` tensors.
    Logs the stacked mean loss (weighted by total sample count via
    ``batch_size``) and stashes detached samples for image logging.
    """
    sub_losses = []
    total_samples = 0
    for sample in batch:
        src = sample["source"]
        tgt = sample["target"]
        prediction = self.forward(src)
        sub_loss = self.loss_function(prediction, tgt)
        sub_losses.append(sub_loss)
        total_samples += src.shape[0]
        # only keep images for the first few batches of each epoch
        if batch_idx < self.log_batches_per_epoch:
            self.training_step_outputs.extend(
                self._detach_sample((src, tgt, prediction))
            )
    mean_loss = torch.stack(sub_losses).mean()
    self.log(
        "loss/train",
        mean_loss.to(self.device),
        on_step=True,
        on_epoch=True,
        prog_bar=True,
        logger=True,
        sync_dist=True,
        batch_size=total_samples,
    )
    return mean_loss
196204

197205
def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0):
198-
source = batch["source"]
199-
target = batch["target"]
206+
source: Tensor = batch["source"]
207+
target: Tensor = batch["target"]
200208
pred = self.forward(source)
201209
loss = self.loss_function(pred, target)
202-
self.log("loss/validate", loss, sync_dist=True, add_dataloader_idx=False)
210+
if dataloader_idx + 1 > len(self.validation_losses):
211+
self.validation_losses.append([])
212+
self.validation_losses[dataloader_idx].append(loss.detach())
213+
self.log(
214+
f"loss/val/{dataloader_idx}",
215+
loss.to(self.device),
216+
sync_dist=True,
217+
batch_size=source.shape[0],
218+
)
203219
if batch_idx < self.log_batches_per_epoch:
204220
self.validation_step_outputs.extend(
205221
self._detach_sample((source, target, pred))
@@ -309,8 +325,16 @@ def on_train_epoch_end(self):
309325
self.training_step_outputs = []
310326

311327
def on_validation_epoch_end(self):
    """Log validation sample images and the dataloader-averaged loss.

    Averages losses within each dataloader first, then across
    dataloaders, so loaders with more batches do not dominate the
    combined ``loss/validate`` metric.
    """
    super().on_validation_epoch_end()
    self._log_samples("val_samples", self.validation_step_outputs)
    self.validation_step_outputs = []
    if self.validation_losses:
        # average within each dataloader; stack (not torch.tensor) to
        # combine the recorded 0-dim loss tensors without a copy warning
        loss_means = [torch.stack(losses).mean() for losses in self.validation_losses]
        self.log(
            "loss/validate",
            torch.stack(loss_means).mean().to(self.device),
            sync_dist=True,
        )
    # reset so per-dataloader losses do not accumulate across epochs
    self.validation_losses = []
314338

315339
def on_test_start(self):
316340
"""Load CellPose model for segmentation."""
@@ -386,7 +410,6 @@ class FcmaeUNet(VSUNet):
386410
def __init__(self, fit_mask_ratio: float = 0.0, **kwargs):
    """Masked-autoencoder variant of the virtual-staining U-Net.

    :param float fit_mask_ratio: ratio passed to the FCMAE model during
        fitting — presumably the fraction of the input to mask
        (0.0 disables masking); TODO confirm against the model docs
    :param kwargs: forwarded to the parent constructor;
        ``architecture`` is fixed to ``"fcmae"``
    """
    super().__init__(architecture="fcmae", **kwargs)
    self.fit_mask_ratio = fit_mask_ratio
390413

391414
def forward(self, x: Tensor, mask_ratio: float = 0.0):
    """Run the FCMAE model on ``x`` with the given masking ratio."""
    output = self.model(x, mask_ratio)
    return output
@@ -438,13 +461,3 @@ def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0
438461
self.validation_step_outputs.extend(
439462
self._detach_sample((source, target * mask.unsqueeze(2), pred))
440463
)
441-
442-
def on_validation_epoch_end(self):
443-
super().on_validation_epoch_end()
444-
# average within each dataloader
445-
loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses]
446-
self.log(
447-
"loss/validate",
448-
torch.tensor(loss_means).mean().to(self.device),
449-
sync_dist=True,
450-
)

0 commit comments

Comments
 (0)