Test step to handle non-scalar outputs #4609
Labels
feature: Is an improvement or enhancement
help wanted: Open to be worked on
won't fix: This will not be worked on
🚀 Feature
Handle output from test loop not being a single value.
Motivation
I often need to use a callback to do some processing on test values (to make plots, etc.), which I like to keep separate from the core module code. In this case, I would like to use `on_test_batch_end` to build a list of predicted values calculated in the core `test_step`.
Pitch
To make this work, I need to output an object from `test_step`, something like `{"loss": loss, "predictions": preds, "truth": truth}`. However, the test loop runs `.item()` on any torch tensors, which doesn't work if the outputs are non-scalar. It would be cool if the test loop handled this situation; otherwise the output from the test loop (and therefore any input to callbacks) is limited to scalar tensors.