diff --git a/pl_bolts/callbacks/self_supervised.py b/pl_bolts/callbacks/self_supervised.py index 4ff7216a18..9fb7bf57f1 100644 --- a/pl_bolts/callbacks/self_supervised.py +++ b/pl_bolts/callbacks/self_supervised.py @@ -72,7 +72,7 @@ def to_device(self, batch, device): return x, y - def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): x, y = self.to_device(batch, pl_module.device) with torch.no_grad(): @@ -131,7 +131,7 @@ def __init__(self, initial_tau=0.996): self.initial_tau = initial_tau self.current_tau = initial_tau - def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): # get networks online_net = pl_module.online_network target_net = pl_module.target_network diff --git a/pl_bolts/callbacks/vision/confused_logit.py b/pl_bolts/callbacks/vision/confused_logit.py index 3e6f92da50..2af5d0d72a 100644 --- a/pl_bolts/callbacks/vision/confused_logit.py +++ b/pl_bolts/callbacks/vision/confused_logit.py @@ -55,7 +55,7 @@ def __init__( self.logging_batch_interval = logging_batch_interval self.min_logit_value = min_logit_value - def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): # show images only every 20 batches if (trainer.batch_idx + 1) % self.logging_batch_interval != 0: return diff --git a/pl_bolts/models/self_supervised/byol/byol_module.py b/pl_bolts/models/self_supervised/byol/byol_module.py index 95c68bbee7..b983988ea7 100644 --- a/pl_bolts/models/self_supervised/byol/byol_module.py +++ b/pl_bolts/models/self_supervised/byol/byol_module.py @@ -97,9 +97,9 @@ def __init__(self, self.target_network = deepcopy(self.online_network) self.weight_callback = BYOLMAWeightUpdate() - def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, dataloader_idx: int) -> None: # Add callback for user automatically since it's key to BYOL weight update - self.weight_callback.on_train_batch_end(self.trainer, self, batch, batch_idx, dataloader_idx) + self.weight_callback.on_train_batch_end(self.trainer, self, outputs, batch, batch_idx, dataloader_idx) def forward(self, x): y, _, _ = self.online_network(x)