-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
81 lines (65 loc) · 3.13 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms
from clef.beam_search import CaptionGenerator
import torch.nn as nn
class CaptionModel(nn.Module):
    """Image-captioning model: CNN encoder feeding an LSTM decoder.

    The CNN's classification head (``cnn.fc``) is replaced by a linear
    projection into the word-embedding space, so the encoded image is
    consumed as the first "token" of the caption sequence.
    """

    def __init__(self, cnn, vocab, embedding_size=256, rnn_size=256, num_layers=2,
                 share_embedding_weights=False):
        """
        Args:
            cnn: backbone module exposing an ``fc`` linear head (e.g. a
                torchvision ResNet); the head is replaced here.
            vocab: sequence of tokens; a token's index is its class id.
            embedding_size: dimension of word/image embeddings.
            rnn_size: LSTM hidden size.
            num_layers: number of stacked LSTM layers.
            share_embedding_weights: tie the input embedding and output
                classifier weights (requires embedding_size == rnn_size).

        Raises:
            ValueError: if weight tying is requested with incompatible sizes.
        """
        super(CaptionModel, self).__init__()
        self.vocab = vocab
        self.cnn = cnn
        # Replace the classification head with an embedding projection.
        self.cnn.fc = nn.Linear(self.cnn.fc.in_features, embedding_size)
        self.rnn = nn.LSTM(embedding_size, rnn_size, num_layers=num_layers)
        self.classifier = nn.Linear(rnn_size, len(vocab))
        self.embedder = nn.Embedding(len(self.vocab), embedding_size)
        if share_embedding_weights:
            # classifier.weight is (len(vocab), rnn_size) and embedder.weight
            # is (len(vocab), embedding_size): tying is only shape-compatible
            # when the two dims agree.  Fail loudly here instead of with a
            # confusing shape error at the first forward pass.
            if embedding_size != rnn_size:
                raise ValueError(
                    'share_embedding_weights requires embedding_size == rnn_size '
                    '(got %d and %d)' % (embedding_size, rnn_size))
            self.embedder.weight = self.classifier.weight

    def forward(self, imgs, captions, lengths):
        """Teacher-forced forward pass.

        Args:
            imgs: image batch accepted by the CNN.
            captions: LongTensor of token ids, shape (seq_len, batch) —
                time-first, matching the LSTM's default layout.
            lengths: per-sample sequence lengths *including* the prepended
                image step, sorted in decreasing order as required by
                ``pack_padded_sequence``.

        Returns:
            tuple of (logits over the vocab for every packed timestep,
            final LSTM state).
        """
        embeddings = self.embedder(captions)
        # Prepend the image embedding as timestep 0 of every sequence.
        img_feats = self.cnn(imgs).unsqueeze(0)
        embeddings = torch.cat([img_feats, embeddings], 0)
        packed_embeddings = pack_padded_sequence(embeddings, lengths)
        feats, state = self.rnn(packed_embeddings)
        # feats[0] is the packed data tensor; classify all timesteps at once.
        pred = self.classifier(feats[0])
        return pred, state

    def generate(self, img, scale_size=256, crop_size=224,
                 eos_token='EOS', beam_size=3,
                 max_caption_length=20,
                 length_normalization_factor=0.0):
        """Beam-search captions for a single preprocessed image tensor.

        Note: ``scale_size`` and ``crop_size`` are kept for interface
        compatibility but are unused here — preprocessing is expected to
        happen before the call.

        Returns:
            list of candidate captions as space-joined token strings.
        """
        cap_gen = CaptionGenerator(embedder=self.embedder,
                                   rnn=self.rnn,
                                   classifier=self.classifier,
                                   eos_id=self.vocab.index(eos_token),
                                   beam_size=beam_size,
                                   max_caption_length=max_caption_length,
                                   length_normalization_factor=length_normalization_factor)
        if next(self.parameters()).is_cuda:
            img = img.cuda()
        # Fix: `Variable(..., volatile=True)` is a no-op since PyTorch 0.4,
        # so the original tracked gradients during inference.  Use the
        # modern no_grad context instead.
        with torch.no_grad():
            img_feats = self.cnn(img.unsqueeze(0)).unsqueeze(0)
            sentences, score = cap_gen.beam_search(img_feats)
        sentences = [' '.join([self.vocab[idx] for idx in sent])
                     for sent in sentences]
        return sentences

    def save_checkpoint(self, filename):
        """Save per-module state dicts, the vocab, and the whole model.

        NOTE(review): storing ``'model': self`` pickles the full module and
        makes the checkpoint fragile across code changes — the state dicts
        are the portable part.
        """
        torch.save({'embedder_dict': self.embedder.state_dict(),
                    'rnn_dict': self.rnn.state_dict(),
                    'cnn_dict': self.cnn.state_dict(),
                    'classifier_dict': self.classifier.state_dict(),
                    'vocab': self.vocab,
                    'model': self},
                   filename)

    def load_checkpoint(self, filename):
        """Restore module weights saved by :meth:`save_checkpoint`.

        The CNN weights are optional so older checkpoints without a
        ``cnn_dict`` entry still load.
        """
        cpnt = torch.load(filename)
        if 'cnn_dict' in cpnt:
            self.cnn.load_state_dict(cpnt['cnn_dict'])
        self.embedder.load_state_dict(cpnt['embedder_dict'])
        self.rnn.load_state_dict(cpnt['rnn_dict'])
        self.classifier.load_state_dict(cpnt['classifier_dict'])

    def finetune_cnn(self, allow=True):
        """Enable/disable gradients for the CNN backbone.

        The replaced ``fc`` projection is always trainable regardless of
        ``allow`` — it is new and must be learned.
        """
        for p in self.cnn.parameters():
            p.requires_grad = allow
        for p in self.cnn.fc.parameters():
            p.requires_grad = True