better accuracy
close #2 #4 #6
SSARCandy committed Jan 6, 2018
1 parent 2563a9e commit 3cff49d
Showing 8 changed files with 14 additions and 137 deletions.
Binary file modified demo/accuracy.png
Binary file modified demo/loss.png
Binary file modified demo/result.png
5 changes: 2 additions & 3 deletions main.py
@@ -14,7 +14,7 @@
WEIGHT_DECAY = 5e-4
MOMENTUM = 0.9
BATCH_SIZE = [200, 56]
-EPOCHS = 10
+EPOCHS = 20


source_loader = get_office31_dataloader(case='amazon', batch_size=BATCH_SIZE[0])
@@ -45,8 +45,7 @@ def train(model, optimizer, epoch, _lambda):
        out1, out2 = model(source_data, target_data)

        classification_loss = torch.nn.functional.cross_entropy(out1, source_label)
-        coral = models.CORAL()
-        coral_loss = coral(out1, out2)
+        coral_loss = models.CORAL(out1, out2)

        sum_loss = _lambda*coral_loss + classification_loss
        sum_loss.backward()
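The substance of this hunk: CORAL used to be instantiated as a torch.autograd.Function and then applied; it is now an ordinary function called directly on the two activation batches, with autograd supplying the backward pass. A minimal sketch of the resulting training step (train_step and the optimizer handling are illustrative assumptions, not the repository's exact code):

import torch
import models  # the repository's models.py

def train_step(model, optimizer, source_data, source_label, target_data, _lambda):
    optimizer.zero_grad()
    out1, out2 = model(source_data, target_data)

    # supervised loss on the labeled source batch
    classification_loss = torch.nn.functional.cross_entropy(out1, source_label)
    # CORAL is now a plain differentiable function; no Function instance needed
    coral_loss = models.CORAL(out1, out2)

    # CORAL term weighted by _lambda, as in the diff above
    sum_loss = _lambda * coral_loss + classification_loss
    sum_loss.backward()
    optimizer.step()
    return sum_loss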
49 changes: 12 additions & 37 deletions models.py
@@ -5,52 +5,27 @@
CUDA = True if torch.cuda.is_available() else False


-def feature_covariance_mat(n, d):
-    ones_t = torch.ones(n).view(1, -1)
-    if CUDA:
-        ones_t = ones_t.cuda()
-
-    tmp = ones_t.matmul(d)
-    covariance_mat = (d.t().matmul(d) - (tmp.t().matmul(tmp) / n)) / (n - 1)
-    return covariance_mat
-
-
-def forbenius_norm(mat):
-    return (mat**2).sum()**0.5
-
-
'''
MODELS
'''


-class CORAL(Function):
-    def forward(self, source, target):
-        d = source.shape[1]
-        ns, nt = source.shape[0], target.shape[0]
-        cs = feature_covariance_mat(ns, source)
-        ct = feature_covariance_mat(nt, target)
-
-        self.saved = (source, target, cs, ct, ns, nt, d)
-
-        res = forbenius_norm(cs - ct)**2/(4*d*d)
-        res = torch.FloatTensor([res])
-
-        return res if CUDA is False else res.cuda()
-
-    def backward(self, grad_output):
-        source, target, cs, ct, ns, nt, d = self.saved
-        ones_s_t = torch.ones(ns).view(1, -1)
-        ones_t_t = torch.ones(nt).view(1, -1)
-        if CUDA:
-            ones_s_t = ones_s_t.cuda()
-            ones_t_t = ones_t_t.cuda()
-
-        s_gradient = (source.t() - (ones_s_t.matmul(source).t().matmul(ones_s_t)/ns)).t().matmul(cs - ct) / (d*d*(ns - 1))
-        t_gradient = (target.t() - (ones_t_t.matmul(target).t().matmul(ones_t_t)/nt)).t().matmul(cs - ct) / (d*d*(nt - 1))
-        t_gradient = -t_gradient
-
-        return s_gradient*grad_output, t_gradient*grad_output
+def CORAL(source, target):
+    d = source.data.shape[1]
+
+    # source covariance
+    xm = torch.mean(source, 1, keepdim=True) - source
+    xc = torch.matmul(torch.transpose(xm, 0, 1), xm)
+
+    # target covariance
+    xmt = torch.mean(target, 1, keepdim=True) - target
+    xct = torch.matmul(torch.transpose(xmt, 0, 1), xmt)
+
+    # frobenius norm between source and target
+    loss = torch.mean(torch.mul((xc - xct), (xc - xct)))
+    loss = loss/(4*d*4)
+
+    return loss


class DeepCORAL(nn.Module):
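For reference, the CORAL loss of Sun & Saenko, "Deep CORAL: Correlation Alignment for Deep Domain Adaptation" (ECCV 2016 workshops), is ||C_S - C_T||_F^2 / (4 d^2), where C_S and C_T are the d x d covariance matrices of the source and target feature batches. Two spots in the new function are worth checking against that definition: torch.mean(source, 1, keepdim=True) subtracts each sample's mean over its features rather than the per-feature (column) mean, and the 4*d*4 denominator reads like a typo for 4*d*d (the deleted Function divided by 4*d*d). A minimal paper-faithful sketch, assuming source and target are (n_s, d) and (n_t, d) feature matrices:

import torch

def coral_loss(source, target):
    d = source.shape[1]

    # batch covariances with the per-feature (column) mean removed
    xm = source - source.mean(0, keepdim=True)
    cs = xm.t().matmul(xm) / (source.shape[0] - 1)

    xmt = target - target.mean(0, keepdim=True)
    ct = xmt.t().matmul(xmt) / (target.shape[0] - 1)

    # squared Frobenius norm of the covariance gap, scaled by 1/(4 d^2)
    return ((cs - ct) ** 2).sum() / (4 * d * d)

Since the whole expression is differentiable, autograd derives the gradients that the removed backward() computed by hand, which is what makes those 37 deleted lines safe to drop.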
Empty file removed tests/__init__.py
49 changes: 0 additions & 49 deletions tests/fixtures.py

This file was deleted.

48 changes: 0 additions & 48 deletions tests/test.py

This file was deleted.
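The deleted tests presumably targeted the hand-written backward pass that this commit removes; under autograd, an equivalent spot check is torch.autograd.gradcheck. A sketch, assuming models.CORAL accepts plain double-precision tensors:

import torch
import models

# gradcheck compares autograd's analytic gradients against finite differences;
# it wants double-precision inputs with requires_grad set
source = torch.randn(8, 5, dtype=torch.double, requires_grad=True)
target = torch.randn(6, 5, dtype=torch.double, requires_grad=True)

print(torch.autograd.gradcheck(models.CORAL, (source, target)))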
