From c93021e2d54691193c11511f070c55b4564e0150 Mon Sep 17 00:00:00 2001
From: SSARCandy
Date: Thu, 26 Apr 2018 23:08:19 +0800
Subject: [PATCH] Clean up redundant code, close #15

---
 main.py   | 13 ++++---------
 models.py | 12 +++++-------
 2 files changed, 9 insertions(+), 16 deletions(-)

diff --git a/main.py b/main.py
index fb8ad62..361e206 100644
--- a/main.py
+++ b/main.py
@@ -22,8 +22,6 @@ def train(model, optimizer, epoch, _lambda):
-    model.train()
-
     result = []
 
     # Expected size : xs -> (batch_size, 3, 300, 300), ys -> (batch_size)
@@ -76,7 +74,7 @@ def train(model, optimizer, epoch, _lambda):
     return result
 
 
-def test(model, dataset_loader, e, mode='source'):
+def test(model, dataset_loader, e):
     model.eval()
     test_loss = 0
     correct = 0
@@ -85,9 +83,7 @@
         data, target = data.cuda(), target.cuda()
         data, target = Variable(data, volatile=True), Variable(target)
 
-        out1, out2 = model(data, data)
-
-        out = out1 if mode == 'source' else out2
+        out, _ = model(data, data)
 
         # sum up batch loss
         test_loss += torch.nn.functional.cross_entropy(out, target, size_average=False).data[0]
@@ -133,8 +129,7 @@ def load_pretrained(model):
     # i.e. 10 times learning rate for the last two fc layers.
     optimizer = torch.optim.SGD([
         {'params': model.sharedNet.parameters()},
-        {'params': model.source_fc.parameters(), 'lr': 10*LEARNING_RATE},
-        {'params': model.target_fc.parameters(), 'lr': 10*LEARNING_RATE}
+        {'params': model.fc.parameters(), 'lr': 10*LEARNING_RATE},
     ], lr=LEARNING_RATE, momentum=MOMENTUM)
 
     if CUDA:
@@ -163,7 +158,7 @@ def load_pretrained(model):
         training_statistic.append(res)
 
         test_source = test(model, source_loader, e)
-        test_target = test(model, target_loader, e, mode='target')
+        test_target = test(model, target_loader, e)
         testing_s_statistic.append(test_source)
         testing_t_statistic.append(test_target)
diff --git a/models.py b/models.py
index 2442193..6dae31d 100644
--- a/models.py
+++ b/models.py
@@ -23,7 +23,7 @@ def CORAL(source, target):
 
     # frobenius norm between source and target
     loss = torch.mean(torch.mul((xc - xct), (xc - xct)))
-    loss = loss/(4*d*4)
+    loss = loss/(4*d*d)
 
     return loss
 
@@ -32,19 +32,17 @@ class DeepCORAL(nn.Module):
     def __init__(self, num_classes=1000):
         super(DeepCORAL, self).__init__()
         self.sharedNet = AlexNet()
-        self.source_fc = nn.Linear(4096, num_classes)
-        self.target_fc = nn.Linear(4096, num_classes)
+        self.fc = nn.Linear(4096, num_classes)
 
         # initialize according to CORAL paper experiment
-        self.source_fc.weight.data.normal_(0, 0.005)
-        self.target_fc.weight.data.normal_(0, 0.005)
+        self.fc.weight.data.normal_(0, 0.005)
 
     def forward(self, source, target):
         source = self.sharedNet(source)
-        source = self.source_fc(source)
+        source = self.fc(source)
 
         target = self.sharedNet(target)
-        target = self.source_fc(target)
+        target = self.fc(target)
 
         return source, target
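
The CORAL() fix restores the normalization from the Deep CORAL paper (Sun & Saenko, 2016), which defines the loss as the squared Frobenius distance between the source and target feature covariances scaled by 1/(4*d^2), d being the feature dimension; hence the denominator is 4*d*d, not 4*d*4. Below is a minimal standalone sketch of that formula, assuming 2-D (batch, features) inputs. The name coral_loss and the covariance helper are illustrative, not this repository's code, and the sketch sums the squared entries where the repository's version averages them with torch.mean, so the two differ only by a constant factor of d*d for a given feature dimension.

    import torch

    def coral_loss(source, target):
        # CORAL loss, Sun & Saenko (2016): ||C_S - C_T||_F^2 / (4 * d^2),
        # where C_S and C_T are the d x d feature covariance matrices.
        # Illustrative sketch, not the repository's exact implementation.
        n_s, d = source.size()
        n_t = target.size(0)

        def covariance(x, n):
            # unbiased covariance of an (n, d) batch -> (d, d)
            centered = x - x.mean(dim=0, keepdim=True)
            return centered.t() @ centered / (n - 1)

        c_s = covariance(source, n_s)
        c_t = covariance(target, n_t)

        # squared Frobenius norm of the difference, scaled by 1/(4*d^2)
        return ((c_s - c_t) ** 2).sum() / (4 * d * d)

    # e.g. activations shaped like this model's fc input: (batch, 4096)
    src, tgt = torch.randn(32, 4096), torch.randn(32, 4096)
    print(coral_loss(src, tgt))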
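
The models.py half of the patch removes a classifier that was never exercised: the old forward() already sent both streams through source_fc, so target_fc contributed nothing to either output yet still sat in its own 10x-learning-rate parameter group in the optimizer. With a single shared fc, both entries returned by model(data, data) come from the same head, which is why test() can drop its mode argument and simply discard the second output.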