Make this work for TripletMarginLoss #69
Hi @vinod1234567890. This question is similar to the one filed in #61; you can check that issue for further details about the solution. But since a triplet sampler is a bit different from typical data loaders, here is an explanation and example for it. The relevant snippet is the forward-pass operation in torch_lr_finder/lr_finder.py, lines 371 to 378 at commit acc5e7e.
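In essence, that training step does the following (a paraphrase based on the comparison below, not a verbatim copy of the code at that commit):

```python
# Simplified view of LRFinder's per-batch training step:
inputs, labels = next(train_iter)
# ... inputs and labels are moved to the configured device ...
outputs = self.model(inputs)
loss = self.criterion(outputs, labels)
```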
Since we know that a triplet sampler returns 3 values (a query, a positive image, and a negative image), we want to make it work in the following format:

```python
# Desired format:
(query, positive_image, negative_image), labels = next(train_iter)
outputs = self.model((query, positive_image, negative_image))
loss = self.criterion(outputs, labels)
```

Let's compare the two snippets above and use wrappers (the tip mentioned in #61) in order to keep the library code unchanged:
```python
# The actual meaning of each object instance:
inputs, labels = next(train_iter)       # inputs: (query, positive_image, negative_image), labels: None
outputs = self.model(inputs)            # model: ModelWrapper()
loss = self.criterion(outputs, labels)  # criterion: LossFunctionWrapper()
```

Now we can implement those wrappers:

```python
import torch.nn as nn
import torch.optim as optim

from torch_lr_finder import LRFinder, TrainDataLoaderIter


# See also: https://github.com/davidtvs/pytorch-lr-finder/blob/acc5e7e/torch_lr_finder/lr_finder.py#L31-L41
class MyTrainDataLoaderIter(TrainDataLoaderIter):
    def inputs_labels_from_batch(self, batch_data):
        # Since a triplet sampler returns 3 images, we use `None` as `labels`
        # in order to follow the expected `inputs, labels` output format
        return batch_data, None


# NOTE: This is not a fully functional DeepRanking network, please replace it with your own implementation
class DeepRankingNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedder = ...

    def forward(self, query, positive_image, negative_image):
        query_embedding = self.embedder(query)
        positive_embedding = self.embedder(positive_image)
        negative_embedding = self.embedder(negative_image)
        return query_embedding, positive_embedding, negative_embedding


# NOTE: Use this wrapper to unpack input data
class ModelWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inputs):
        query, positive_image, negative_image = inputs  # unpack
        outputs = self.model(query, positive_image, negative_image)
        return outputs


# NOTE: Use this wrapper to unpack outputs
class LossFunctionWrapper(nn.Module):
    def __init__(self, loss_func):
        super().__init__()
        self.loss_func = loss_func

    def forward(self, outputs, labels):
        query_embedding, positive_embedding, negative_embedding = outputs
        return self.loss_func(
            anchor=query_embedding,
            positive=positive_embedding,
            negative=negative_embedding,
        )


if __name__ == '__main__':
    # Prepare your own dataset and data loader
    data_loader = TripletSampler(YOUR_IMAGE_DATASET)
    model = DeepRankingNet()
    loss_func = nn.TripletMarginLoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.5)

    # Create wrappers
    trainloader_wrapper = MyTrainDataLoaderIter(data_loader)
    model_wrapper = ModelWrapper(model)
    loss_func_wrapper = LossFunctionWrapper(loss_func)

    # Run LRFinder
    lr_finder = LRFinder(model_wrapper, optimizer, loss_func_wrapper, device='cuda')
    lr_finder.range_test(
        trainloader_wrapper, end_lr=1, num_iter=10, step_mode='exp',
        start_lr=1e-5
    )
    lr_finder.plot()
    lr_finder.reset()
```

And that's it. Hope this helps! But if you have further questions, please feel free to let me know. :)
thanks!
Thanks for the great discussion! I just want to add that I ran into an error when customizing my wrappers, and it turns out that you need to initialize the super class first with `super().__init__()` before assigning any submodules.
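For anyone who hits the same error, here is a minimal sketch of the pitfall (hypothetical wrapper names; the only point is the ordering of `super().__init__()`):

```python
import torch.nn as nn

class BadWrapper(nn.Module):
    def __init__(self, model):
        # Assigning a submodule before nn.Module is initialized raises an
        # AttributeError ("cannot assign module before Module.__init__() call").
        self.model = model
        super().__init__()

class GoodWrapper(nn.Module):
    def __init__(self, model):
        super().__init__()  # initialize nn.Module first
        self.model = model  # submodule registration now works
```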
@sivannavis Thanks! 😄
Can we make this work for nn.TripletMarginLoss?
where the dataset object returns
query, positive_image, negative_image
which are passed to the model one by one, and the three resulting embeddings are passed to the loss function (https://medium.com/@akarshzingade/image-similarity-using-deep-ranking-c1bd83855978).
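For context, a dataset yielding such triplets could look roughly like the sketch below (a hypothetical class using random tensors so it stays self-contained; a real dataset would load and transform actual images):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class RandomTripletDataset(Dataset):
    """Toy dataset that yields (query, positive_image, negative_image) triplets."""

    def __init__(self, num_samples=100, image_shape=(3, 224, 224)):
        self.num_samples = num_samples
        self.image_shape = image_shape

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Stand-ins for loaded/transformed images
        query = torch.randn(self.image_shape)
        positive = torch.randn(self.image_shape)
        negative = torch.randn(self.image_shape)
        return query, positive, negative

# Each batch unpacks into three tensors, matching the wrappers discussed above.
loader = DataLoader(RandomTripletDataset(), batch_size=8, shuffle=True)
query, positive_image, negative_image = next(iter(loader))
```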