diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 85d31be09a6c..56caa1c71c94 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2294,9 +2294,7 @@ def _inner_training_loop( tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) else: if tr_loss.device != tr_loss_step.device: - raise ValueError( - f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}" - ) + tr_loss_step = tr_loss_step.to(tr_loss.device) tr_loss += tr_loss_step self.current_flos += float(self.floating_point_ops(inputs))