From 5209f4704da616ef6fe97dd06dc7c88342adc6e4 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Feb 2021 23:49:57 +0100 Subject: [PATCH] fix --- pytorch_lightning/accelerators/accelerator.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index d33c7506e165f..d568fd525b25f 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -218,23 +218,6 @@ def validation_step_end(self, output): """ return self.training_type_plugin.validation_step_end(output) - def predict(self, args): - """The prediction step. - - Args: - args: the arguments for the models predict step. Can consist of the following: - batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): - The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. - batch_idx (int): Integer displaying index of this batch - optimizer_idx (int): When using multiple optimizers, this argument will also be present. - hiddens(:class:`~torch.Tensor`): Passed in if - :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0. - - """ - batch = self.to_device(args[0]) - args[0] = batch - return self.training_type_plugin.predict(*args) - def backward( self, closure_loss: torch.Tensor,