Inconsistent API for on_predict_epoch_end #8479

Closed
ananthsub opened this issue Jul 20, 2021 · 7 comments · Fixed by #16655
Assignees
Labels
design (Includes a design discussion), feature (Is an improvement or enhancement), hooks (Related to the hooks API), trainer: predict
Milestone

Comments

@ananthsub
Contributor

ananthsub commented Jul 20, 2021

🚀 Feature

Background
We are auditing the Lightning components and APIs to assess opportunities for improvement.

One item that came up was on_predict_epoch_end(), defined on the ModelHooks mixin. It accepts an outputs: List[Any] argument, which is inconsistent with the other model hooks of the same type: on_train_epoch_end, on_validation_epoch_end, and on_test_epoch_end.
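For illustration, the shape of the mismatch (signatures paraphrased from the ModelHooks mixin; exact annotations and defaults may differ across versions):

from typing import Any, List

class ModelHooks:
    def on_train_epoch_end(self) -> None: ...
    def on_validation_epoch_end(self) -> None: ...
    def on_test_epoch_end(self) -> None: ...
    # the inconsistent one: it also receives the collected outputs
    def on_predict_epoch_end(self, outputs: List[Any]) -> None: ...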

Motivation

API consistency with the other epoch-end model hooks.

Pitch

  • Add a prediction_epoch_end hook to the LightningModule (sketched below)
  • Deprecate the outputs argument from on_predict_epoch_end in v1.5 and remove it entirely in v1.7
  • Update the prediction loop to cache predictions only if prediction_epoch_end is implemented
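A sketch of the shape the pitched hook might take (prediction_epoch_end is the proposed name and never shipped in this form; MyModel and the toy layer are illustrative only):

import torch
import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def predict_step(self, batch, batch_idx):
        return self.layer(batch)

    def prediction_epoch_end(self, outputs):
        # hypothetical hook from the pitch: it would receive the cached
        # per-batch outputs, mirroring the then-existing
        # validation_epoch_end / test_epoch_end pattern
        return torch.vstack(outputs)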

Users can also avoid this entirely and cache their prediction outputs as needed by implementing that logic in predict_step directly.
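A minimal sketch of that approach (CachingModel, cached_preds, and the toy layer are illustrative, not part of the Lightning API):

import torch
import pytorch_lightning as pl

class CachingModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)
        self.cached_preds = []  # user-managed cache, not a Lightning-managed attribute

    def predict_step(self, batch, batch_idx):
        preds = self.layer(batch)
        self.cached_preds.append(preds)  # cache manually instead of relying on `outputs`
        return preds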

Alternatives

Keep as is?

Additional context

cc @Borda @tchaton @justusschock @awaelchli @carmocca @ninginthecloud @daniellepintz @rohitgr7

@ananthsub ananthsub added feature Is an improvement or enhancement help wanted Open to be worked on design Includes a design discussion labels Jul 20, 2021
@ananthsub ananthsub added this to the v1.5 milestone Jul 20, 2021
@awaelchli awaelchli modified the milestones: v1.5, v1.6 Nov 4, 2021
@carmocca carmocca added good first issue Good for newcomers hooks Related to the hooks API trainer: predict labels Feb 1, 2022
@carmocca carmocca modified the milestones: 1.6, future Feb 14, 2022
@ishtos
Contributor

ishtos commented Apr 4, 2022

I'd like to work on this; could you assign this issue to me?

@mattcleigh

mattcleigh commented Mar 3, 2023

So I just want to follow up on the standard way of returning data from the predict step now.
Ideally I would want this all defined in the model class so I don't have to tweak the prediction script for each model I use.

Would it be something like this? It seems less intuitive to me than the older method and adds a lot of boilerplate code:

import pytorch_lightning as pl
import torch as T

class Model(pl.LightningModule):
    def __init__(self, *args):
        super().__init__()
        self.save_hyperparameters(logger=False)
        self.predict_step_outputs = []
        self.predict_outputs = None
        ...

    def on_predict_epoch_start(self) -> None:
        self.predict_outputs = None

    def predict_step(self, batch: tuple, _batch_idx: int) -> None:
        outputs = ...  # however the outputs are calculated
        self.predict_step_outputs.append(outputs)

    def on_predict_epoch_end(self) -> None:  # the hook no longer receives `outputs`
        combined_predictions = ...  # however they are combined, e.g. T.vstack
        self.predict_outputs = combined_predictions
        self.predict_step_outputs.clear()

# Then in the export/predict script
trainer.predict(model=model, datamodule=datamodule)
outputs = model.predict_outputs

@carmocca
Contributor

carmocca commented Mar 6, 2023

That works. You could also do the combination after trainer.predict() so that you don't need the predict_outputs attribute. It's up to you now to do it however you prefer!
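For example (a sketch assuming each predict_step returns a tensor and trainer/model/datamodule are set up as above):

import torch

# combine after the fact instead of storing a `predict_outputs` attribute;
# assumes the per-batch tensors have matching trailing dimensions
batch_predictions = trainer.predict(model=model, datamodule=datamodule)
combined = torch.vstack(batch_predictions)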

@mattcleigh

So there is no way to override how the trainer.predict() method combines the outputs of each batch? Will it always return a list?

@carmocca
Contributor

carmocca commented Mar 6, 2023

Are you asking about the returned predictions, i.e. predictions = trainer.predict(..., return_predictions=True)?

@mattcleigh

mattcleigh commented Mar 6, 2023

Exactly. By default it returns a list of whatever the model's predict_step returned for each batch. If I wanted to do some post-processing on the entire collection of predictions, or some simple stacking to get a single tensor out, I would have to do it afterwards. I was wondering whether there is a built-in function for this, or whether I would have to define my own, like:

predictions = trainer.predict(..., return_predictions=True)
predictions = model.process_and_combine(predictions)

@carmocca
Contributor

carmocca commented Mar 7, 2023

You would have to define your own.
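For instance, along the lines sketched above (process_and_combine is the commenter's hypothetical name, not a Lightning API; this version assumes tensor outputs):

import torch

def process_and_combine(predictions):
    # user-defined helper: stack the per-batch outputs, then apply any
    # further post-processing as needed
    return torch.vstack(predictions)

predictions = trainer.predict(model=model, datamodule=datamodule, return_predictions=True)
predictions = process_and_combine(predictions)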
