Deterministically get activation_index, fixed indentation, added support for python3 #4

Open · wants to merge 3 commits into base: master
Changes from 1 commit
deterministically get activation index
kendricktan committed Aug 21, 2017
commit f6c054f1f106a8d6d044d29b766b17f57b034e0d
42 changes: 23 additions & 19 deletions finetune.py
@@ -1,3 +1,4 @@
+import copy
 import torch
 from torch.autograd import Variable
 from torchvision import models
@@ -65,30 +66,33 @@ def forward(self, x):
             x = module(x)
             if isinstance(module, torch.nn.modules.conv.Conv2d):
                 x.register_hook(self.compute_rank)
Owner:
self.compute_rank is now a function that returns a function (hook). It looks like the PyTorch hook mechanism will call compute_rank, which will return hook as a function object (but won't run it), so self.filter_ranks won't be computed anywhere.
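
For context, a minimal standalone illustration of the situation described above (placeholder names, not the PR's code): if the factory itself is registered rather than the hook it returns, the inner function is never invoked.

# Placeholder stand-ins for the PR's compute_rank / register_hook wiring.
filter_ranks = {}

def compute_rank(activation_index):
    def hook(grad):
        filter_ranks[activation_index] = grad.sum()
    return hook

# What the diff registers:
#     x.register_hook(compute_rank)
# Autograd would hand the gradient to the factory; hook is built and returned,
# but nothing ever calls it, so filter_ranks stays empty.
#
# What makes the closure actually run during backward:
#     x.register_hook(compute_rank(activation_index))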

Author:
self.compute_rank now returns a function (hook). So when self.compute_rank(activation_index) is called, hook (a partial function capturing the local variable activation_index) is passed in as the callback function for register_hook.

So when the gradients are updated, hook is called, but it doesn't need to calculate the activation_index, because that was already given when self.compute_rank(INDEX) was called.
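
A minimal, self-contained sketch of that pattern outside the class (illustrative names such as make_rank_hook; the per-filter score mirrors the PR's normalization but is not its exact code). The returned hook fires on backward and uses the captured index:

import torch

filter_ranks = {}

def make_rank_hook(activation_index, activation):
    # Hook factory: the closure captures activation_index and the activation.
    def hook(grad):
        # Per-filter score: activation * grad summed over batch and spatial dims,
        # normalized by those same dimensions.
        values = (activation * grad).sum(dim=(0, 2, 3)).detach()
        values = values / (activation.size(0) * activation.size(2) * activation.size(3))
        filter_ranks[activation_index] = filter_ranks.get(activation_index, 0) + values
    return hook

conv = torch.nn.Conv2d(4, 6, kernel_size=3, padding=1)
x = torch.randn(2, 4, 8, 8)
out = conv(x)                              # requires grad through conv's parameters
out.register_hook(make_rank_hook(0, out))  # the returned hook is what gets registered
out.sum().backward()
print(filter_ranks[0].shape)               # torch.Size([6]): one score per filter

Here register_hook receives the inner hook object, so autograd calls it with the gradient and the captured index needs no recomputation.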

Owner:
Thanks.
But if so, then wasn't the intention to do:
x.register_hook(self.compute_rank(activation_index))
self.activations.append(x)

Otherwise x isn't appended to self.activations and can't be used from within hook, and PyTorch isn't registering the gradient callback to the partial function returned by self.compute_rank.
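
For reference, a hedged sketch of how the two pieces could fit together with that suggestion applied. It follows the names used in the diff, but the loop header over self.model.features is assumed from context, the class name is illustrative, and torch.zeros replaces the PR's .cuda() call so the sketch runs on CPU; it is not the merged code.

import torch

class FilterPrunnerSketch:
    def __init__(self, model):
        self.model = model
        self.filter_ranks = {}

    def forward(self, x):
        self.activations = []
        self.activation_to_layer = {}
        activation_index = 0
        for layer, (name, module) in enumerate(self.model.features._modules.items()):
            x = module(x)
            if isinstance(module, torch.nn.modules.conv.Conv2d):
                # Register the returned hook, not the factory itself ...
                x.register_hook(self.compute_rank(activation_index))
                # ... and still append x so the hook can look up its activation.
                self.activations.append(x)
                self.activation_to_layer[activation_index] = layer
                activation_index += 1
        return self.model.classifier(x.view(x.size(0), -1))

    def compute_rank(self, activation_index):
        # Hook factory: the closure keeps activation_index, so the old
        # self.grad_index bookkeeping is no longer needed.
        def hook(grad):
            activation = self.activations[activation_index]
            values = (activation * grad).sum(dim=(0, 2, 3)).detach()
            values = values / (activation.size(0) * activation.size(2) * activation.size(3))
            if activation_index not in self.filter_ranks:
                self.filter_ranks[activation_index] = torch.zeros(activation.size(1))
            self.filter_ranks[activation_index] += values
        return hook

With a torchvision VGG-style model and a 224x224 input batch, calling forward and then backward on a loss fills filter_ranks with one score tensor per conv layer.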

-                self.activations.append(x)
+                self.activations.append(self.compute_rank(activation_index))
                 self.activation_to_layer[activation_index] = layer
                 activation_index += 1

         return self.model.classifier(x.view(x.size(0), -1))

-    def compute_rank(self, grad):
-        activation_index = len(self.activations) - self.grad_index - 1
-        activation = self.activations[activation_index]
-        values = \
-            torch.sum((activation * grad), dim=0).\
-            sum(dim=2).sum(dim=3)[0, :, 0, 0].data
-
-        # Normalize the rank by the filter dimensions
-        values = \
-            values / (activation.size(0) * activation.size(2)
-                      * activation.size(3))
-
-        if activation_index not in self.filter_ranks:
-            self.filter_ranks[activation_index] = \
-                torch.FloatTensor(activation.size(1)).zero_().cuda()
-
-        self.filter_ranks[activation_index] += values
-        self.grad_index += 1
+    def compute_rank(self, activation_index):
+        # Returns a partial function
+        # as the callback function
+        def hook(grad):
+            activation = self.activations[activation_index]
+            values = \
+                torch.sum((activation * grad), dim=0).\
+                sum(dim=2).sum(dim=3)[0, :, 0, 0].data
+
+            # Normalize the rank by the filter dimensions
+            values = \
+                values / (activation.size(0) * activation.size(2)
+                          * activation.size(3))
+
+            if activation_index not in self.filter_ranks:
+                self.filter_ranks[activation_index] = \
+                    torch.FloatTensor(activation.size(1)).zero_().cuda()
+
+            self.filter_ranks[activation_index] += values
+            self.grad_index += 1
+        return hook

     def lowest_ranking_filters(self, num):
         data = []