Batch Hessian-Vector Products w.r.t. input #343

Open

MaddyThakker opened this issue Jan 2, 2025 · 1 comment

Comments

MaddyThakker commented Jan 2, 2025

What I’m Doing Now (in PyTorch)

I want to compute Hessian-vector products (HVPs) w.r.t. the input for a classification problem that uses NLLLoss or CrossEntropyLoss. My main attempts so far:

  1. functorch.jvp + vmap (forward-over-reverse)

    import torch
    from functorch import grad, jvp, vmap  # torch.func in newer PyTorch
    
    criterion = torch.nn.CrossEntropyLoss()
    
    def single_sample_hvp(model, x_single, y_single, v_single):
        def loss_fn(x):
            out = model(x.unsqueeze(0))
            return criterion(out, y_single.unsqueeze(0))
    
        grad_of_loss = grad(loss_fn)
        # forward-mode JVP of the gradient function: HVP = JVP(grad f, v)
        _, hvp_val = jvp(grad_of_loss, (x_single,), (v_single,))
        return hvp_val
    
    def batched_hvp(model, x_batch, y_batch, v_batch):
        return vmap(single_sample_hvp, in_dims=(None, 0, 0, 0))(
            model, x_batch, y_batch, v_batch
        )

    However, I run into:

    NotImplementedError: Trying to use forward AD with _log_softmax_backward_data ...
    

    because _log_softmax_backward_data (the backward of log_softmax used inside these losses) has no forward-mode AD rule, so the forward-over-reverse approach fails. (A vmap-compatible reverse-over-reverse variant is sketched after these attempts.)

  2. Double-Backward for HVP
    I can compute HVPs with two passes of standard reverse-mode AD (which works for NLLLoss), looping over each sample in Python:

    import torch
    
    def hvp_single_sample(model, criterion, x, y, v):
        # x must be a leaf tensor with requires_grad=True
        out = model(x.unsqueeze(0))
        loss = criterion(out, y.unsqueeze(0))
        grad_x = torch.autograd.grad(loss, x, create_graph=True)[0]
        # second reverse-mode pass through the gradient gives H @ v
        hvp = torch.autograd.grad(grad_x, x, grad_outputs=v)[0]
        return hvp
    
    # Then loop over the batch:
    def hvp_batch(model, criterion, X, Y, V):
        hvps = []
        for i in range(X.size(0)):
            x_i = X[i].detach().requires_grad_()  # detach so x_i is a leaf even if X requires grad
            y_i = Y[i]
            v_i = V[i]
            hvps.append(hvp_single_sample(model, criterion, x_i, y_i, v_i))
        return torch.stack(hvps, dim=0)

    This works, but it is slow for large batches because create_graph=True makes it hard to simply wrap the loop in vmap.
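
A vmap-compatible alternative, for reference: reverse-over-reverse computes the HVP as the gradient of <grad loss(x), v>, so it only ever uses backward-mode rules and never hits the missing forward-AD rule above. A minimal, untested sketch (assumes torch.func from PyTorch >= 2.0; the model, criterion, and shapes are placeholders):

    import torch
    from torch.func import grad, vmap

    model = torch.nn.Linear(10, 3)           # placeholder model
    criterion = torch.nn.CrossEntropyLoss()  # placeholder loss

    def per_sample_loss(x_single, y_single):
        out = model(x_single.unsqueeze(0))
        return criterion(out, y_single.unsqueeze(0))

    def single_sample_hvp(x_single, y_single, v_single):
        # reverse-over-reverse: HVP = grad_x <grad_x loss(x), v>
        g = grad(per_sample_loss)  # gradient w.r.t. x_single
        gv = lambda x: torch.sum(g(x, y_single) * v_single)
        return grad(gv)(x_single)

    X = torch.randn(8, 10)
    Y = torch.randint(0, 3, (8,))
    V = torch.randn(8, 10)
    hvps = vmap(single_sample_hvp)(X, Y, V)  # per-sample HVPs, shape (8, 10)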

What I’d Like to Achieve

  1. Compute per-sample (or aggregated) Hessian-vector products for NLLLoss / CrossEntropyLoss across a large batch.
  2. Bypass the slow Python loop required by double-backward.
  3. See if there is an API or recommended usage in BackPACK that can handle this scenario efficiently.

Questions

  • Does BackPACK already have a built-in approach for HVPs w.r.t. the input?
  • Any suggestions for a batched implementation that avoids interference between datapoints in the same batch, essentially making the for loop fast? (One idea is sketched below.)
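
For the second question, one idea (an untested sketch, and it assumes the model has no cross-sample coupling such as BatchNorm): with a sum-reduced loss, dloss/dX[i] depends only on sample i, so the Hessian w.r.t. the batched input X is block-diagonal across the batch, and a single double-backward should produce all per-sample HVPs at once:

    import torch

    def batch_hvp(model, criterion, X, Y, V):
        # criterion is assumed to use reduction="sum" so that the
        # gradient w.r.t. X[i] only involves sample i
        X = X.detach().requires_grad_()
        loss = criterion(model(X), Y)
        (g,) = torch.autograd.grad(loss, X, create_graph=True)
        # block-diagonal Hessian: row i of the result is H_i @ V[i]
        (hvp,) = torch.autograd.grad(g, X, grad_outputs=V)
        return hvp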

Thanks in advance for any guidance!

@MaddyThakker changed the title from "Batch Hessian-Vector Products" to "Batch Hessian-Vector Products w.r.t. input" on Jan 2, 2025
MaddyThakker (Author) commented

@f-dangel any ideas?
