
Commit 42391fd

Add VQA2, VisualGenome, FBResNet152 (for pytorch)
Factory
- vqa models, convnets and vqa datasets can now be created via factories (see the usage sketch below)

VQA 2.0
- VQA2(AbstractVQA) added

VisualGenome
- VisualGenome(AbstractVQADataset) added for merging with VQA datasets
- VisualGenomeImages(AbstractImagesDataset) added to extract features
- `extract.py` now allows extracting VisualGenome features

Variable features size
- `extract.py` now allows extracting features from images of size != 448 via the CLI arg `--size`
- FeaturesDataset now has an optional `opt['size']` parameter

FBResNet152
- `convnets.py` provides support for external pretrained models as well as the ResNets from torchvision
- in particular, FBResNet152 is a port of the fbresnet152torch model from torch7 that was used until now
1 parent 57752d6 commit 42391fd

23 files changed (+1070 -126 lines)
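To make the new factory usage concrete, here is a minimal sketch (not code from this commit): it relies only on the `convnets.factory` call that appears in the `extract.py` diff below, and assumes a CUDA machine with the repo on the Python path.

```python
# Minimal sketch, not part of this commit: build the extraction CNN via the new
# factory, mirroring the convnets.factory(...) call used in extract.py below.
import torch
from torch.autograd import Variable

import vqa.models.convnets as convnets

# 'fbresnet152' is the torch7 port mentioned above; any entry of
# convnets.model_names should work the same way.
model = convnets.factory({'arch': 'fbresnet152'}, cuda=True, data_parallel=True)

# Dummy forward pass, as extract.py does to size its HDF5 datasets
# (e.g. 1 x 2048 x 14 x 14 for a 448 x 448 input).
x = Variable(torch.ones(1, 3, 448, 448), volatile=True)
print(model(x).size())
```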

.gitmodules (+3)
@@ -4,3 +4,6 @@
 [submodule "vqa/external/skip-thoughts.torch"]
     path = vqa/external/skip-thoughts.torch
     url = https://github.com/Cadene/skip-thoughts.torch.git
+[submodule "vqa/external/pretrained-models.pytorch"]
+    path = vqa/external/pretrained-models.pytorch
+    url = https://github.com/Cadene/pretrained-models.pytorch.git

README.md (+55 -20)
@@ -1,13 +1,20 @@
 # Visual Question Answering in pytorch
 
-This repo was made by [Remi Cadene](http://remicadene.com) (LIP6) and [Hedi Ben-Younes](https://twitter.com/labegne) (LIP6-Heuritech), two PhD Students working on VQA at [UPMC-LIP6](http://lip6.fr) and their professors [Matthieu Cord](http://webia.lip6.fr/~cord) (LIP6) and [Nicolas Thome](http://webia.lip6.fr/~thomen) (LIP6-CNAM). We developped this code in the frame of a research paper called [MUTAN: Multimodal Tucker Fusion for VQA](https://arxiv.org/abs/1705.06676) which is (as far as we know) the current state-of-the-art on the [VQA-1 dataset](http://visualqa.org).
+This repo was made by [Remi Cadene](http://remicadene.com) (LIP6) and [Hedi Ben-Younes](https://twitter.com/labegne) (LIP6-Heuritech), two PhD students working on VQA at [UPMC-LIP6](http://lip6.fr), and their professors [Matthieu Cord](http://webia.lip6.fr/~cord) (LIP6) and [Nicolas Thome](http://webia.lip6.fr/~thomen) (LIP6-CNAM). We developed this code as part of a research paper called [MUTAN: Multimodal Tucker Fusion for VQA](https://arxiv.org/abs/1705.06676), which is (as far as we know) the current state of the art on the [VQA 1.0 dataset](http://visualqa.org).
 
 The goal of this repo is twofold:
 - to make it easier to reproduce our results,
 - to provide an efficient and modular code base to the community for further research on other VQA datasets.
 
 If you have any questions about our code or model, don't hesitate to contact us or to submit any issues. Pull requests are welcome!
 
+#### News:
+
+- coming soon: pretrained models on VQA2, features of FBResnet152, web app demo
+- 18th July 2017: VQA2, VisualGenome, FBResnet152 (for pytorch) added
+- 16th July 2017: paper accepted at ICCV2017
+- 30th May 2017: poster accepted at CVPR2017 (VQA Workshop)
+
 #### Summary:
 
 * [Introduction](#introduction)
@@ -27,7 +34,10 @@ If you have any questions about our code or model, don't hesitate to contact us
 * [Models](#models)
 * [Quick examples](#quick-examples)
 * [Extract features from COCO](#extract-features-from-coco)
-* [Train models on VQA](#train-models-on-vqa)
+* [Extract features from VisualGenome](#extract-features-from-visualgenome)
+* [Train models on VQA 1.0](#train-models-on-vqa-1-0)
+* [Train models on VQA 2.0](#train-models-on-vqa-2-0)
+* [Train models on VQA + VisualGenome](#train-models-on-vqa-2-0)
 * [Monitor training](#monitor-training)
 * [Restart training](#restart-training)
 * [Evaluate models on VQA](#evaluate-models-on-vqa)
@@ -108,7 +118,7 @@ Our code has two external dependencies:
 Data will be automatically downloaded and preprocessed when needed. Links to data are stored in `vqa/datasets/vqa.py` and `vqa/datasets/coco.py`.
 
 
-## Reproducing results
+## Reproducing results on VQA 1.0
 
 ### Features
 
@@ -173,7 +183,7 @@ To obtain test and testdev results, you will need to zip your result json file (
 |
 ├── train.py        # train & eval models
 ├── eval_res.py     # eval results files with OpenEnded metric
-├── extract.pt      # extract features from coco with CNNs
+├── extract.py      # extract features from coco with CNNs
 └── visu.py         # visualize logs and monitor training
 ```
 
@@ -189,16 +199,15 @@ You can easily add new options in your custom yaml file if needed. Also, if you w
 
 ### Datasets
 
-We currently provide three datasets:
+We currently provide four datasets:
 
 - [COCOImages](http://mscoco.org/) currently used to extract features; it comes with three datasets: trainset, valset and testset
-- COCOFeatures used by any VQA datasets
-- [VQA](http://www.visualqa.org/vqa_v1_download.html) comes with four datasets: trainset, valset, testset (including test-std and test-dev) and "trainvalset" (concatenation of trainset and valset)
+- [VisualGenomeImages]() currently used to extract features; it comes with one split: trainset
+- [VQA 1.0](http://www.visualqa.org/vqa_v1_download.html) comes with four datasets: trainset, valset, testset (including test-std and test-dev) and "trainvalset" (concatenation of trainset and valset)
+- [VQA 2.0](http://www.visualqa.org) same as VQA 1.0 but twice as big (with the same images as VQA 1.0)
 
 We plan to add:
 
-- [VisualGenome](http://visualgenome.org/)
-- [VQA2](http://www.visualqa.org/)
 - [CLEVR](http://cs.stanford.edu/people/jcjohns/clevr/)
 
 ### Models
@@ -245,7 +254,16 @@ CUDA_VISIBLE_DEVICES=0 python extract.py
 CUDA_VISIBLE_DEVICES=1,2 python extract.py
 ```
 
-### Train models on VQA
+### Extract features from VisualGenome
+
+Same here, but only the train split is available:
+
+```
+python extract.py --dataset vgenome --dir_data data/vgenome --data_split train
+```
+
+
+### Train models on VQA 1.0
 
 Display help message, selected options and run default. The needed data will be automatically downloaded and processed using the options in `options/default.yaml`.
 
@@ -258,19 +276,19 @@ python train.py
 Run a MutanNoAtt model with default options.
 
 ```
-python train.py --path_opt options/vqa/mutan_noatt.yaml --dir_logs logs/vqa/mutan_noatt
+python train.py --path_opt options/vqa/mutan_noatt.yaml --dir_logs logs/vqa/mutan_noatt_train
 ```
 
 Run a MutanAtt model on the trainset and evaluate on the valset after each epoch.
 
 ```
-python train.py --vqa_trainsplit train --path_opt options/vqa/mutan_att.yaml
+python train.py --vqa_trainsplit train --path_opt options/vqa/mutan_att_trainval.yaml
 ```
 
 Run a MutanAtt model on the trainset and valset (by default) and run through the testset after each epoch (this produces a results file that you can submit to the evaluation server).
 
 ```
-python train.py --vqa_trainsplit trainval --path_opt options/vqa/mutan_att.yaml
+python train.py --vqa_trainsplit trainval --path_opt options/vqa/mutan_att_trainval.yaml
 ```
 
 ### Monitor training
@@ -301,6 +319,22 @@ Create a visualization of multiple experiments to compare them or monitor them l
 python visu.py --dir_logs logs/vqa/mutan_noatt,logs/vqa/mutan_att
 ```
 
+### Train models on VQA 2.0
+
+See options of [vqa2/mutan_att_trainval](https://github.com/Cadene/vqa.pytorch/blob/master/options/vqa2/mutan_att_trainval.yaml):
+
+```
+python train.py --path_opt options/vqa2/mutan_att_trainval.yaml
+```
+
+### Train models on VQA (1.0 or 2.0) + VisualGenome
+
+See options of [vqa2/mutan_att_trainval_vg](https://github.com/Cadene/vqa.pytorch/blob/master/options/vqa2/mutan_att_trainval_vg.yaml):
+
+```
+python train.py --path_opt options/vqa2/mutan_att_trainval_vg.yaml
+```
+
 ### Restart training
 
 Restart the model from the last checkpoint.
@@ -329,13 +363,14 @@ Please cite the arXiv paper if you use Mutan in your work:
 
 ```
 @article{benyounescadene2017mutan,
-  title={MUTAN: Multimodal Tucker Fusion for Visual Question Answering},
-  author={Hedi Ben-Younes and
-          R{\'{e}}mi Cad{\`{e}}ne and
-          Nicolas Thome and
-          Matthieu Cord}},
-  journal={arXiv preprint arXiv:1705.06676},
-  year={2017}
+  author = {Hedi Ben-Younes and
+            R{\'{e}}mi Cad{\`{e}}ne and
+            Nicolas Thome and
+            Matthieu Cord},
+  title = {MUTAN: Multimodal Tucker Fusion for Visual Question Answering},
+  journal = {ICCV},
+  year = {2017},
+  url = {http://arxiv.org/abs/1705.06676}
 }
 ```
 
extract.py (+52 -32)
@@ -8,60 +8,73 @@
 import torch.nn as nn
 import torch.nn.parallel
 import torch.backends.cudnn as cudnn
+from torch.autograd import Variable
 
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets
-import torchvision.models as models
 
-import vqa.datasets.coco as coco
+import vqa.models.convnets as convnets
+import vqa.datasets as datasets
 from vqa.lib.dataloader import DataLoader
-from vqa.models.utils import ResNet
 from vqa.lib.logger import AvgMeter
 
-model_names = sorted(name for name in models.__dict__
-    if name.islower() and name.startswith("resnet")
-    and callable(models.__dict__[name]))
-
 parser = argparse.ArgumentParser(description='Extract')
-parser.add_argument('--dir_data', default='data/coco', metavar='DIR',
-                    help='dir dataset: mscoco or visualgenome')
+parser.add_argument('--dataset', default='coco',
+                    choices=['coco', 'vgenome'],
+                    help='dataset type: coco (default) | vgenome')
+parser.add_argument('--dir_data', default='data/coco',
+                    help='dir dataset to download or/and load images')
 parser.add_argument('--data_split', default='train', type=str,
                     help='Options: (default) train | val | test')
-parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet152',
-                    choices=model_names,
+parser.add_argument('--arch', '-a', default='resnet152',
+                    choices=convnets.model_names,
                     help='model architecture: ' +
-                         ' | '.join(model_names) +
-                         ' (default: resnet152)')
-parser.add_argument('--workers', default=4, type=int, metavar='N',
-                    help='number of data loading workers (default: 8)')
-parser.add_argument('--batch_size', '-b', default=80, type=int, metavar='N',
+                         ' | '.join(convnets.model_names) +
+                         ' (default: fbresnet152)')
+parser.add_argument('--workers', default=4, type=int,
+                    help='number of data loading workers (default: 4)')
+parser.add_argument('--batch_size', '-b', default=80, type=int,
                     help='mini-batch size (default: 80)')
 parser.add_argument('--mode', default='both', type=str,
                     help='Options: att | noatt | (default) both')
+parser.add_argument('--size', default=448, type=int,
+                    help='Image size (448 for noatt := avg pooling to get 224) (default: 448)')
 
 
 def main():
+    global args
     args = parser.parse_args()
 
     print("=> using pre-trained model '{}'".format(args.arch))
-    model = models.__dict__[args.arch](pretrained=True)
-    model = ResNet(model, False)
-    model = nn.DataParallel(model).cuda()
+    model = convnets.factory({'arch':args.arch}, cuda=True, data_parallel=True)
 
-    #extract_name = 'arch,{}_layer,{}_resize,{}'.format()
-    extract_name = 'arch,{}'.format(args.arch)
+    extract_name = 'arch,{}_size,{}'.format(args.arch, args.size)
 
-    #dir_raw = os.path.join(args.dir_data, 'raw')
     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])
 
-    dataset = coco.COCOImages(args.data_split, dict(dir=args.dir_data),
-        transform=transforms.Compose([
-            transforms.Scale(448),
-            transforms.CenterCrop(448),
-            transforms.ToTensor(),
-            normalize,
-        ]))
+    if args.dataset == 'coco':
+        if 'coco' not in args.dir_data:
+            raise ValueError('"coco" string not in dir_data')
+        dataset = datasets.COCOImages(args.data_split, dict(dir=args.dir_data),
+            transform=transforms.Compose([
+                transforms.Scale(args.size),
+                transforms.CenterCrop(args.size),
+                transforms.ToTensor(),
+                normalize,
+            ]))
+    elif args.dataset == 'vgenome':
+        if args.data_split != 'train':
+            raise ValueError('train split is required for vgenome')
+        if 'vgenome' not in args.dir_data:
+            raise ValueError('"vgenome" string not in dir_data')
+        dataset = datasets.VisualGenomeImages(args.data_split, dict(dir=args.dir_data),
+            transform=transforms.Compose([
+                transforms.Scale(args.size),
+                transforms.CenterCrop(args.size),
+                transforms.ToTensor(),
+                normalize,
+            ]))
 
     data_loader = DataLoader(dataset,
         batch_size=args.batch_size, shuffle=False,
@@ -79,13 +92,19 @@ def extract(data_loader, model, path_file, mode):
     path_txt = path_file + '.txt'
     hdf5_file = h5py.File(path_hdf5, 'w')
 
+    # estimate output shapes
+    output = model(Variable(torch.ones(1, 3, args.size, args.size),
+                            volatile=True))
+
     nb_images = len(data_loader.dataset)
     if mode == 'both' or mode == 'att':
-        shape_att = (nb_images, 2048, 14, 14)
+        shape_att = (nb_images, output.size(1), output.size(2), output.size(3))
+        print('Warning: shape_att={}'.format(shape_att))
         hdf5_att = hdf5_file.create_dataset('att', shape_att,
                                             dtype='f')#, compression='gzip')
     if mode == 'both' or mode == 'noatt':
-        shape_noatt = (nb_images, 2048)
+        shape_noatt = (nb_images, output.size(1))
+        print('Warning: shape_noatt={}'.format(shape_noatt))
         hdf5_noatt = hdf5_file.create_dataset('noatt', shape_noatt,
                                               dtype='f')#, compression='gzip')
 
@@ -98,7 +117,7 @@ def extract(data_loader, model, path_file, mode):
 
     idx = 0
     for i, input in enumerate(data_loader):
-        input_var = torch.autograd.Variable(input['visual'], volatile=True)
+        input_var = Variable(input['visual'], volatile=True)
         output_att = model(input_var)
 
         nb_regions = output_att.size(2) * output_att.size(3)
@@ -111,6 +130,7 @@ def extract(data_loader, model, path_file, mode):
         hdf5_noatt[idx:idx+batch_size] = output_noatt.data.cpu().numpy()
         idx += batch_size
 
+        torch.cuda.synchronize()
         batch_time.update(time.time() - end)
         end = time.time()
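As a side note, the features written by `extract.py` land in two HDF5 datasets, `att` and `noatt`, which can be read back with plain h5py; a minimal sketch follows (the file path is hypothetical and depends on where `extract.py` wrote its `arch,<arch>_size,<size>` output for your split).

```python
# Minimal sketch, not part of this commit: read back the features written by
# extract.py. The path below is hypothetical; point it at the actual output file.
import h5py

path_hdf5 = 'data/coco/extract/arch,fbresnet152_size,448/trainset.hdf5'  # hypothetical

with h5py.File(path_hdf5, 'r') as f:
    att = f['att']      # spatial features, e.g. (nb_images, 2048, 14, 14) for size 448
    noatt = f['noatt']  # pooled features, e.g. (nb_images, 2048)
    print(att.shape, noatt.shape)
```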

options/vqa2/default.yaml (+40)
@@ -0,0 +1,40 @@
+logs:
+    dir_logs: logs/vqa2/default
+vqa:
+    dataset: VQA2
+    dir: data/vqa2
+    trainsplit: train
+    nans: 2000
+    maxlength: 26
+    minwcount: 0
+    nlp: mcb
+    pad: right
+    samplingans: True
+coco:
+    dir: data/coco
+    arch: fbresnet152
+    mode: noatt
+    size: 448
+model:
+    arch: MLBNoAtt
+    seq2vec:
+        arch: skipthoughts
+        dir_st: data/skip-thoughts
+        type: UniSkip
+        dropout: 0.25
+        fixed_emb: False
+    fusion:
+        dim_v: 2048
+        dim_q: 2400
+        dim_h: 1200
+        dropout_v: 0.5
+        dropout_q: 0.5
+        activation_v: tanh
+        activation_q: tanh
+    classif:
+        activation: tanh
+        dropout: 0.5
+optim:
+    lr: 0.0001
+    batch_size: 512
+    epochs: 100

options/vqa2/mlb_att_trainval.yaml (+49)
@@ -0,0 +1,49 @@
+logs:
+    dir_logs: logs/vqa2/mlb_att_trainval
+vqa:
+    dataset: VQA2
+    dir: data/vqa2
+    trainsplit: trainval
+    nans: 2000
+    maxlength: 26
+    minwcount: 0
+    nlp: mcb
+    pad: right
+    samplingans: True
+coco:
+    dir: data/coco
+    arch: fbresnet152
+    mode: att
+    size: 448
+model:
+    arch: MLBAtt
+    dim_v: 2048
+    dim_q: 2400
+    seq2vec:
+        arch: skipthoughts
+        dir_st: data/skip-thoughts
+        type: BayesianUniSkip
+        dropout: 0.25
+        fixed_emb: False
+    attention:
+        nb_glimpses: 4
+        dim_h: 1200
+        dropout_v: 0.5
+        dropout_q: 0.5
+        dropout_mm: 0.5
+        activation_v: tanh
+        activation_q: tanh
+        activation_mm: tanh
+    fusion:
+        dim_h: 1200
+        dropout_v: 0.5
+        dropout_q: 0.5
+        activation_v: tanh
+        activation_q: tanh
+    classif:
+        activation: tanh
+        dropout: 0.5
+optim:
+    lr: 0.0001
+    batch_size: 128
+    epochs: 100
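These option files can be inspected or tweaked with standard PyYAML before being passed to `train.py --path_opt`; a minimal sketch follows (generic `yaml` usage, not the repo's own option loader; the overridden values are examples only).

```python
# Generic PyYAML sketch, not the repo's own loader: read one of the new option
# files and override a couple of fields. The overridden values are examples only.
import yaml

with open('options/vqa2/mlb_att_trainval.yaml') as handle:
    options = yaml.safe_load(handle)

options['optim']['batch_size'] = 64                     # e.g. to fit a smaller GPU
options['logs']['dir_logs'] = 'logs/vqa2/mlb_att_bs64'  # hypothetical log dir

print(options['model']['arch'], options['coco']['arch'], options['coco']['size'])
```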
