Skip to content

Commit

Permalink
Clean up redundant code, close #15
Browse files Browse the repository at this point in the history
  • Loading branch information
SSARCandy committed Apr 26, 2018
1 parent c6f55b6 commit c93021e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 16 deletions.
13 changes: 4 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@


def train(model, optimizer, epoch, _lambda):
model.train()

result = []

# Expected size : xs -> (batch_size, 3, 300, 300), ys -> (batch_size)
Expand Down Expand Up @@ -76,7 +74,7 @@ def train(model, optimizer, epoch, _lambda):
return result


def test(model, dataset_loader, e, mode='source'):
def test(model, dataset_loader, e):
model.eval()
test_loss = 0
correct = 0
Expand All @@ -85,9 +83,7 @@ def test(model, dataset_loader, e, mode='source'):
data, target = data.cuda(), target.cuda()

data, target = Variable(data, volatile=True), Variable(target)
out1, out2 = model(data, data)

out = out1 if mode == 'source' else out2
out, _ = model(data, data)

# sum up batch loss
test_loss += torch.nn.functional.cross_entropy(out, target, size_average=False).data[0]
Expand Down Expand Up @@ -133,8 +129,7 @@ def load_pretrained(model):
# i.e. 10 times learning rate for the last two fc layers.
optimizer = torch.optim.SGD([
{'params': model.sharedNet.parameters()},
{'params': model.source_fc.parameters(), 'lr': 10*LEARNING_RATE},
{'params': model.target_fc.parameters(), 'lr': 10*LEARNING_RATE}
{'params': model.fc.parameters(), 'lr': 10*LEARNING_RATE},
], lr=LEARNING_RATE, momentum=MOMENTUM)

if CUDA:
Expand Down Expand Up @@ -163,7 +158,7 @@ def load_pretrained(model):
training_statistic.append(res)

test_source = test(model, source_loader, e)
test_target = test(model, target_loader, e, mode='target')
test_target = test(model, target_loader, e)
testing_s_statistic.append(test_source)
testing_t_statistic.append(test_target)

Expand Down
12 changes: 5 additions & 7 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def CORAL(source, target):

# frobenius norm between source and target
loss = torch.mean(torch.mul((xc - xct), (xc - xct)))
loss = loss/(4*d*4)
loss = loss/(4*d*d)

return loss

Expand All @@ -32,19 +32,17 @@ class DeepCORAL(nn.Module):
def __init__(self, num_classes=1000):
super(DeepCORAL, self).__init__()
self.sharedNet = AlexNet()
self.source_fc = nn.Linear(4096, num_classes)
self.target_fc = nn.Linear(4096, num_classes)
self.fc = nn.Linear(4096, num_classes)

# initialize according to CORAL paper experiment
self.source_fc.weight.data.normal_(0, 0.005)
self.target_fc.weight.data.normal_(0, 0.005)
self.fc.weight.data.normal_(0, 0.005)

def forward(self, source, target):
source = self.sharedNet(source)
source = self.source_fc(source)
source = self.fc(source)

target = self.sharedNet(target)
target = self.source_fc(target)
target = self.fc(target)
return source, target


Expand Down

0 comments on commit c93021e

Please sign in to comment.