What I’m Doing Now (in PyTorch)
I want to compute Hessian-vector products (HVPs) w.r.t. the input for a classification problem that uses NLLLoss or CrossEntropyLoss. My main attempts include:
functorch.jvp + vmap
However, I run into:
because built-in classification losses do not support forward-mode AD.
Double-Backward for HVP
I can compute HVPs using standard backward-mode AD twice (which works for NLLLoss), for each sample in a Python loop:
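A minimal sketch of this kind of per-sample double-backward loop. The toy `model`, the helper `hvp_single`, and the tensors `xs`, `ys`, `vs` are hypothetical names chosen for illustration and are not taken from the original post:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy classifier that outputs log-probabilities (assumption).
model = nn.Sequential(
    nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3), nn.LogSoftmax(dim=-1)
)

def hvp_single(x, y, v):
    """HVP of the NLL loss w.r.t. a single input x, in direction v."""
    x = x.detach().requires_grad_(True)
    loss = F.nll_loss(model(x.unsqueeze(0)), y.unsqueeze(0))
    # First backward pass: keep the graph so the gradient is differentiable.
    (grad,) = torch.autograd.grad(loss, x, create_graph=True)
    # Second backward pass: d(grad . v)/dx = H v, since H is symmetric.
    (hv,) = torch.autograd.grad(grad, x, grad_outputs=v)
    return hv

xs = torch.randn(5, 4)          # batch of inputs
ys = torch.randint(0, 3, (5,))  # class labels
vs = torch.randn(5, 4)          # one direction vector per sample

# One double-backward per sample, in a plain Python loop.
hvps = torch.stack([hvp_single(x, y, v) for x, y, v in zip(xs, ys, vs)])
```

Each iteration builds and differentiates its own graph, so the cost grows linearly with the batch size.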
This works, but is slow for large batches because I can't vmap it easily due to create_graph=True.
What I'd Like to Achieve
Efficient HVP computation for NLLLoss/CrossEntropyLoss across a large batch.
Questions
Thanks in advance for any guidance!