@@ -150,6 +150,7 @@ def __init__(
150
150
self .log_batches_per_epoch = log_batches_per_epoch
151
151
self .log_samples_per_batch = log_samples_per_batch
152
152
self .training_step_outputs = []
153
+ self .validation_losses = []
153
154
self .validation_step_outputs = []
154
155
# required to log the graph
155
156
if architecture == "2D" :
@@ -175,31 +176,46 @@ def forward(self, x: Tensor) -> Tensor:
175
176
return self .model (x )
176
177
177
178
def training_step (self , batch : Sample , batch_idx : int ):
178
- source = batch ["source" ]
179
- target = batch ["target" ]
180
- pred = self .forward (source )
181
- loss = self .loss_function (pred , target )
179
+ losses = []
180
+ batch_size = 0
181
+ for b in batch :
182
+ source = b ["source" ]
183
+ target = b ["target" ]
184
+ pred = self .forward (source )
185
+ loss = self .loss_function (pred , target )
186
+ losses .append (loss )
187
+ batch_size += source .shape [0 ]
188
+ if batch_idx < self .log_batches_per_epoch :
189
+ self .training_step_outputs .extend (
190
+ self ._detach_sample ((source , target , pred ))
191
+ )
192
+ loss_step = torch .stack (losses ).mean ()
182
193
self .log (
183
194
"loss/train" ,
184
- loss ,
195
+ loss_step . to ( self . device ) ,
185
196
on_step = True ,
186
197
on_epoch = True ,
187
198
prog_bar = True ,
188
199
logger = True ,
189
200
sync_dist = True ,
201
+ batch_size = batch_size ,
190
202
)
191
- if batch_idx < self .log_batches_per_epoch :
192
- self .training_step_outputs .extend (
193
- self ._detach_sample ((source , target , pred ))
194
- )
195
- return loss
203
+ return loss_step
196
204
197
205
def validation_step (self , batch : Sample , batch_idx : int , dataloader_idx : int = 0 ):
198
- source = batch ["source" ]
199
- target = batch ["target" ]
206
+ source : Tensor = batch ["source" ]
207
+ target : Tensor = batch ["target" ]
200
208
pred = self .forward (source )
201
209
loss = self .loss_function (pred , target )
202
- self .log ("loss/validate" , loss , sync_dist = True , add_dataloader_idx = False )
210
+ if dataloader_idx + 1 > len (self .validation_losses ):
211
+ self .validation_losses .append ([])
212
+ self .validation_losses [dataloader_idx ].append (loss .detach ())
213
+ self .log (
214
+ f"loss/val/{ dataloader_idx } " ,
215
+ loss .to (self .device ),
216
+ sync_dist = True ,
217
+ batch_size = source .shape [0 ],
218
+ )
203
219
if batch_idx < self .log_batches_per_epoch :
204
220
self .validation_step_outputs .extend (
205
221
self ._detach_sample ((source , target , pred ))
@@ -309,8 +325,16 @@ def on_train_epoch_end(self):
309
325
self .training_step_outputs = []
310
326
311
327
def on_validation_epoch_end (self ):
328
+ super ().on_validation_epoch_end ()
312
329
self ._log_samples ("val_samples" , self .validation_step_outputs )
313
330
self .validation_step_outputs = []
331
+ # average within each dataloader
332
+ loss_means = [torch .tensor (losses ).mean () for losses in self .validation_losses ]
333
+ self .log (
334
+ "loss/validate" ,
335
+ torch .tensor (loss_means ).mean ().to (self .device ),
336
+ sync_dist = True ,
337
+ )
314
338
315
339
def on_test_start (self ):
316
340
"""Load CellPose model for segmentation."""
@@ -386,7 +410,6 @@ class FcmaeUNet(VSUNet):
386
410
def __init__ (self , fit_mask_ratio : float = 0.0 , ** kwargs ):
387
411
super ().__init__ (architecture = "fcmae" , ** kwargs )
388
412
self .fit_mask_ratio = fit_mask_ratio
389
- self .validation_losses = []
390
413
391
414
def forward (self , x : Tensor , mask_ratio : float = 0.0 ):
392
415
return self .model (x , mask_ratio )
@@ -438,13 +461,3 @@ def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0
438
461
self .validation_step_outputs .extend (
439
462
self ._detach_sample ((source , target * mask .unsqueeze (2 ), pred ))
440
463
)
441
-
442
- def on_validation_epoch_end (self ):
443
- super ().on_validation_epoch_end ()
444
- # average within each dataloader
445
- loss_means = [torch .tensor (losses ).mean () for losses in self .validation_losses ]
446
- self .log (
447
- "loss/validate" ,
448
- torch .tensor (loss_means ).mean ().to (self .device ),
449
- sync_dist = True ,
450
- )
0 commit comments