-
Notifications
You must be signed in to change notification settings - Fork 303
/
Copy pathretrieval_vid_mplug.py
273 lines (219 loc) · 10.5 KB
/
retrieval_vid_mplug.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path
from models.tokenization_bert import BertTokenizer
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader
from models.model_retrieval_mplug import MPLUG
from models.vit import interpolate_pos_embed, resize_pos_embed
import utils
from dataset.video_dataset import VideoDataset
@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, config):
    """Score every video against every text with a two-stage ranker.

    Stage one computes a similarity matrix from the projected, normalized
    video and text embeddings. Stage two re-scores, for each query, only the
    ``config['k_test']`` best candidates with ``model.fusion_encoder`` and
    ``model.itm_head``; every other entry keeps the fill value ``-100.0``.

    Args:
        model: model exposing ``text_encoder``, ``text_proj``,
            ``visual_encoder.visual``, ``visn_fc``, ``visn_layer_norm``,
            ``vision_proj``, ``fusion_encoder`` and ``itm_head``.
        data_loader: loader yielding ``(video, video_id)`` batches; its
            ``dataset.text`` attribute holds the texts to encode.
        tokenizer: callable tokenizer returning an object with ``input_ids``
            and ``attention_mask``.
        device: device on which the model runs.
        config: mapping; only ``config['k_test']`` is read here.

    Returns:
        Tuple ``(score_matrix_v2t, score_matrix_t2v)`` of numpy arrays with
        shapes ``(num_videos, num_texts)`` and ``(num_texts, num_videos)``.
    """
    model.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Evaluation:'

    print('Computing features for evaluation...')
    start_time = time.time()

    # ---- Text side: encode all texts in chunks of text_bs. ----
    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 256
    text_feats = []
    text_embeds = []
    text_atts = []
    for i in range(0, num_text, text_bs):
        text = texts[i: min(num_text, i + text_bs)]
        text_input = tokenizer(text, padding='max_length', truncation=True, max_length=35,
                               return_tensors="pt").to(device)
        text_output = model.text_encoder(text_input.input_ids, attention_mask=text_input.attention_mask)
        text_feat = text_output.last_hidden_state
        # The first token's hidden state is projected and normalized for the similarity stage.
        text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:, 0, :]))
        text_embeds.append(text_embed)
        text_feats.append(text_feat)
        text_atts.append(text_input.attention_mask)
    text_embeds = torch.cat(text_embeds, dim=0)
    text_feats = torch.cat(text_feats, dim=0)
    text_atts = torch.cat(text_atts, dim=0)

    # ---- Video side: encode the N entries along dim 1 of each sample independently. ----
    video_feats = []
    video_embeds = []
    for video, video_id in data_loader:
        B, N, C, W, H = video.size()
        video = video.view(-1, C, W, H)
        video = video.to(device, non_blocking=True)
        video_feat = model.visual_encoder.visual(video, skip_last_layer=True)
        video_feat = model.visn_layer_norm(model.visn_fc(video_feat))
        video_embed = model.vision_proj(video_feat[:, 0, :])
        # Average the N per-entry embeddings into one embedding per sample.
        video_embed = video_embed.view(B, N, -1).mean(dim=1)
        video_embed = F.normalize(video_embed, dim=-1)
        # Concatenate the token sequences of the N entries for the fusion stage.
        video_feat = video_feat.view(B, -1, video_feat.shape[-1])
        # Token features are parked on the CPU; only the embeddings stay on `device`.
        video_feats.append(video_feat.cpu())
        video_embeds.append(video_embed)
    video_feats = torch.cat(video_feats, dim=0)
    video_embeds = torch.cat(video_embeds, dim=0)

    sims_matrix = video_embeds @ text_embeds.t()
    num_video = sims_matrix.size(0)

    num_tasks = utils.get_world_size()
    rank = utils.get_rank()

    # ---- Video -> text re-ranking. ----
    # Sized from the similarity matrix so unequal video/text counts are handled.
    score_matrix_v2t = torch.full((num_video, num_text), -100.0).to(device)
    # topk raises when k exceeds the number of candidates, so clamp it.
    k_v2t = min(config['k_test'], num_text)

    # Each rank handles one contiguous slice of the query rows.
    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
        topk_sim, topk_idx = sims.topk(k=k_v2t, dim=0)
        encoder_output = video_feats[start + i].repeat(k_v2t, 1, 1).to(device, non_blocking=True)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device, non_blocking=True)
        _, output = model.fusion_encoder(encoder_embeds=text_feats[topk_idx],
                                         attention_mask=text_atts[topk_idx],
                                         encoder_hidden_states=encoder_output,
                                         encoder_attention_mask=encoder_att,
                                         return_dict=False,
                                         )
        # Column 1 of the itm_head output is added to the stage-one similarity.
        score = model.itm_head(output[:, 0, :])[:, 1]
        score_matrix_v2t[start + i, topk_idx] = score + topk_sim

    # ---- Text -> video re-ranking. ----
    sims_matrix = sims_matrix.t()
    score_matrix_t2v = torch.full((num_text, num_video), -100.0).to(device)
    k_t2v = min(config['k_test'], num_video)

    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
        topk_sim, topk_idx = sims.topk(k=k_t2v, dim=0)
        # video_feats lives on the CPU, so the index tensor must be moved there too.
        encoder_output = video_feats[topk_idx.cpu()].to(device, non_blocking=True)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device, non_blocking=True)
        _, output = model.fusion_encoder(encoder_embeds=text_feats[start + i].repeat(k_t2v, 1, 1),
                                         attention_mask=text_atts[start + i].repeat(k_t2v, 1),
                                         encoder_hidden_states=encoder_output,
                                         encoder_attention_mask=encoder_att,
                                         return_dict=False,
                                         )
        score = model.itm_head(output[:, 0, :])[:, 1]
        score_matrix_t2v[start + i, topk_idx] = score + topk_sim

    # Ask the process group directly instead of reading a module-level `args`,
    # so the function also works when imported and called from other code.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
        # Merge the per-rank row slices into the full matrices.
        torch.distributed.all_reduce(score_matrix_v2t, op=torch.distributed.ReduceOp.SUM)
        torch.distributed.all_reduce(score_matrix_t2v, op=torch.distributed.ReduceOp.SUM)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Evaluation time {}'.format(total_time_str))

    return score_matrix_v2t.cpu().numpy(), score_matrix_t2v.cpu().numpy()
@torch.no_grad()
def itm_eval(scores_v2t, scores_t2v, txt2vmg, vid2txt):
    """Compute recall@{1,5,10} in both retrieval directions from score matrices.

    Args:
        scores_v2t: 2-D numpy array; row ``i`` holds the scores of video ``i``
            against every text.
        scores_t2v: 2-D numpy array; row ``j`` holds the scores of text ``j``
            against every video.
        txt2vmg: indexable; ``txt2vmg[j]`` is the matching video index of text ``j``.
        vid2txt: indexable; ``vid2txt[i]`` is the matching text index of video ``i``.

    Returns:
        Dict with recall percentages and their means for both directions, the
        median 1-based rank for text->video (``vid_mdR``), and the overall
        mean ``r_mean``.
    """

    def ground_truth_ranks(score_rows, targets):
        # 0-based position of each row's target in that row's descending ordering.
        positions = np.zeros(score_rows.shape[0])
        for row_no, row in enumerate(score_rows):
            descending = np.argsort(row)[::-1]
            positions[row_no] = np.where(descending == targets[row_no])[0][0]
        return positions

    def recall_at(positions, cutoff):
        # Percentage of queries whose target ranks strictly before `cutoff`.
        hits = int(np.count_nonzero(positions < cutoff))
        return 100.0 * hits / len(positions)

    # Video->Text
    v2t_positions = ground_truth_ranks(scores_v2t, vid2txt)
    tr1 = recall_at(v2t_positions, 1)
    tr5 = recall_at(v2t_positions, 5)
    tr10 = recall_at(v2t_positions, 10)

    # Text->Video
    t2v_positions = ground_truth_ranks(scores_t2v, txt2vmg)
    mdR = np.median(t2v_positions + 1)
    vr1 = recall_at(t2v_positions, 1)
    vr5 = recall_at(t2v_positions, 5)
    vr10 = recall_at(t2v_positions, 10)

    tr_mean = (tr1 + tr5 + tr10) / 3
    vr_mean = (vr1 + vr5 + vr10) / 3
    r_mean = (tr_mean + vr_mean) / 2

    return {
        'txt_r1': tr1,
        'txt_r5': tr5,
        'txt_r10': tr10,
        'txt_r_mean': tr_mean,
        'vid_r1': vr1,
        'vid_r5': vr5,
        'vid_r10': vr10,
        'vid_r_mean': vr_mean,
        'vid_mdR': mdR,
        'r_mean': r_mean,
    }
def main(args, config):
    """Evaluate a retrieval checkpoint on the test set and append the metrics to a file.

    Builds the test dataset and model, optionally loads ``args.checkpoint``
    (resizing the visual positional embedding and renaming encoder keys),
    runs ``evaluation`` and, on the main process only, computes the metrics
    with ``itm_eval`` and appends them as one JSON line to
    ``<args.output_dir>/test_result.txt``.

    Args:
        args: parsed command-line namespace; reads ``device``, ``seed``,
            ``text_encoder``, ``checkpoint``, ``output_dir``, ``distributed``
            and ``gpu`` (the latter only when ``args.distributed`` is true).
        config: mapping; reads ``video_root``, ``ann_root``, ``num_frm_test``,
            ``image_size``, ``batch_size`` and, when a checkpoint is given,
            ``clip_name`` and ``image_res``.

    Raises:
        ValueError: if a checkpoint is given and ``config['clip_name']`` is
            not one of the supported backbones.
        KeyError: if the checkpoint has neither a ``'model'`` nor a
            ``'module'`` entry.
    """
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating retrieval dataset")
    test_dataset = VideoDataset(config['video_root'], config['ann_root'], num_frm=config['num_frm_test'],
                                max_img_size=config['image_size'], frm_sampling_strategy='uniform')
    test_loader = DataLoader(
        test_dataset,
        batch_size=config['batch_size'],
        num_workers=4,
        pin_memory=True,
        drop_last=False,
        shuffle=False,
    )

    #### Model ####
    print("Creating model")
    tokenizer = BertTokenizer.from_pretrained(args.text_encoder)
    model = MPLUG(config=config, tokenizer=tokenizer)

    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint, map_location='cpu')
        # Weights are stored under 'model' or, failing that, under 'module'.
        try:
            state_dict = checkpoint['model']
        except KeyError:
            state_dict = checkpoint['module']

        # reshape positional embedding to accommodate for image resolution change
        patch_sizes = {"ViT-B-16": 16, "ViT-L-14": 14}
        clip_name = config["clip_name"]
        if clip_name not in patch_sizes:
            # Fail with a clear message instead of an unbound-variable error further down.
            raise ValueError(
                "Unsupported clip_name %r; expected one of %s" % (clip_name, sorted(patch_sizes)))
        patch = patch_sizes[clip_name]
        num_patches = int(config["image_res"] * config["image_res"] / (patch * patch))

        pos_key = 'visual_encoder.visual.positional_embedding'
        old_pos_embed = state_dict[pos_key]
        # The zero tensor is a template for the target shape; take its width from
        # the checkpoint so it matches the backbone rather than assuming 768.
        pos_embed = nn.Parameter(torch.zeros(num_patches + 1, old_pos_embed.shape[-1]).float())
        pos_embed = resize_pos_embed(old_pos_embed.unsqueeze(0), pos_embed.unsqueeze(0))
        state_dict[pos_key] = pos_embed

        # Strip the 'fusion.' / 'bert.' path segments from non-decoder keys.
        for key in list(state_dict.keys()):
            if ('fusion' in key or 'bert' in key) and 'decode' not in key:
                encoder_key = key.replace('fusion.', '').replace('bert.', '')
                # Pop before assigning: when the name is unchanged, assign-then-delete
                # would silently drop the weight from the state dict.
                state_dict[encoder_key] = state_dict.pop(key)

        msg = model.load_state_dict(state_dict, strict=False)
        print('load checkpoint from %s' % args.checkpoint)
        print(msg)

    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    score_v2t, score_t2v = evaluation(model_without_ddp, test_loader, tokenizer, device, config)

    # Only the main process computes metrics and writes the result file.
    if utils.is_main_process():
        test_result = itm_eval(score_v2t, score_t2v, test_loader.dataset.txt2video, test_loader.dataset.video2txt)
        print(test_result)

        log_stats = {str(k): v for k, v in test_result.items()}
        with open(os.path.join(args.output_dir, "test_result.txt"), "a") as f:
            f.write(json.dumps(log_stats) + "\n")
if __name__ == '__main__':
    # Script entry point: parse the CLI, load the YAML config, save a copy of
    # it in the output directory, then run the evaluation.

    def _str2bool(value):
        """Convert a command-line string to a bool.

        ``type=bool`` treats every non-empty string (including ``"False"``)
        as True, so the flag is parsed explicitly here.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        # The empty string stays False, as it was under ``type=bool``.
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/retrieval_msrvtt.yaml')
    parser.add_argument('--output_dir', default='output/Retrieval_msrvtt')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--text_encoder', default='bert-base-uncased')
    parser.add_argument('--checkpoint', default='')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--distributed', default=True, type=_str2bool)
    args = parser.parse_args()

    # NOTE(review): yaml.Loader is not a safe loader — only pass trusted config
    # files, or confirm whether a safe loader can parse these configs.
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)
    # The command-line text encoder overrides whatever the config file specifies.
    config['text_encoder'] = args.text_encoder

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Keep a copy of the effective config next to the results.
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_copy:
        yaml.dump(config, config_copy)

    main(args, config)