
PyTorch-adapt Multi-label classifications #100



You need to change CLossHook's loss function to torch.nn.BCEWithLogitsLoss, because multi-label targets need an independent sigmoid per class rather than a softmax over all classes. The hook you're using should let you pass in a custom c_hook (or a similarly named argument). I can help more with this if you tell me which hook you're using.
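For context on why the loss has to change: a single-label loss such as CrossEntropyLoss applies a softmax across classes and expects one class index per sample. A multi-label target is a multi-hot vector in which any number of classes can be 1, so BCEWithLogitsLoss instead treats each class as its own binary problem with a sigmoid on each logit. The dependency-free sketch below mirrors what BCEWithLogitsLoss computes with its default `reduction="mean"`. The function name `bce_with_logits` is hypothetical and exists only for illustration; it is not part of PyTorch or pytorch_adapt.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over all (sample, class) pairs.

    Hypothetical helper mirroring torch.nn.BCEWithLogitsLoss with
    reduction="mean". `logits` and `targets` are lists of equal-length
    rows; each target entry is 0.0 or 1.0 (multi-hot labels).
    """
    total = 0.0
    count = 0
    for logit_row, target_row in zip(logits, targets):
        for x, y in zip(logit_row, target_row):
            # Numerically stable form of
            # -(y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x)))
            total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
            count += 1
    return total / count


# Two samples, three classes; sample 0 has classes 0 and 2 active at once,
# which a single-label (softmax) loss cannot express.
logits = [[2.0, -1.0, 0.5], [-0.3, 1.2, -2.0]]
targets = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
print(bce_with_logits(logits, targets))
```

In PyTorch itself the targets passed to BCEWithLogitsLoss are a float tensor with the same shape as the logits, one column per class.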

An alternative is to "monkey patch" the init function of CLossHook:

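import functools

import torch
from pytorch_adapt.hooks import CLossHook


def init_modifier(method):
    # Wrap CLossHook.__init__: run the original constructor first,
    # then replace the loss it set with a multi-label loss.
    @functools.wraps(method)
    def modify(self, *args, **kwargs):
        method(self, *args, **kwargs)
        self.loss_fn = torch.nn.BCEWithLogitsLoss()
    return modify


# Apply the patch before any hooks are constructed.
CLossHook.__init__ = init_modifier(CLossHook.__init__)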

After this, every hook that uses CLossHook will use torch.nn.BCEWithLogitsLoss.
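The patch works by wrapping the original `__init__`: the original constructor runs first, then the wrapper overwrites `self.loss_fn`. One consequence is that only hook objects constructed after the patch runs pick up the new loss. The stand-in below shows this timing with a hypothetical `ToyCLossHook` class and string placeholders for the losses, so it runs without torch or pytorch_adapt; it is an illustration of the pattern, not pytorch_adapt's actual class.

```python
import functools


class ToyCLossHook:
    """Hypothetical stand-in for CLossHook: stores a loss in self.loss_fn."""

    def __init__(self, loss_fn="cross_entropy"):
        self.loss_fn = loss_fn


def init_modifier(method):
    # Wrap the original __init__ so it still runs, then override the loss.
    @functools.wraps(method)  # keep the original name and docstring
    def modify(self, *args, **kwargs):
        method(self, *args, **kwargs)
        self.loss_fn = "bce_with_logits"  # placeholder for BCEWithLogitsLoss()
    return modify


before = ToyCLossHook()  # created before the patch: keeps the old loss
ToyCLossHook.__init__ = init_modifier(ToyCLossHook.__init__)
after = ToyCLossHook()   # created after the patch: gets the new loss

print(before.loss_fn)  # cross_entropy
print(after.loss_fn)   # bce_with_logits
```

Because the override happens after the original constructor returns, it also wins over any loss passed explicitly as a constructor argument.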

You might get a type error du…


Answer selected by AlexandrByzov
Category: Q&A
Participants: 2