forked from facebookresearch/selavi
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathopt.py
153 lines (143 loc) · 8.12 KB
/
opt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse
def parse_arguments():
    """Build the command-line argument parser for this project.

    Despite its name, this function does not parse anything: it returns the
    configured ``argparse.ArgumentParser`` and the caller is expected to call
    ``parse_args()`` (or ``parse_known_args()``) on it.

    Boolean options use the custom ``'bool'`` type registered below, so they
    take an explicit value on the command line (e.g. ``--use_mlp false``)
    instead of acting as on/off switches. Their defaults are written as
    strings; argparse runs string defaults through the option's type
    function, so the parsed namespace holds real ``bool`` values.

    Returns:
        argparse.ArgumentParser: parser with every option registered.
    """

    def str2bool(v):
        """Convert a command-line token into a ``bool``.

        Accepts ``yes/true/t/1`` and ``no/false/f/0``, case-insensitively.

        Raises:
            argparse.ArgumentTypeError: if the token is not one of the
                accepted spellings.
        """
        normalized = v.lower()
        if normalized in ('yes', 'true', 't', '1'):
            return True
        if normalized in ('no', 'false', 'f', '0'):
            return False
        # ArgumentTypeError (unlike ValueError) makes argparse print this
        # message verbatim instead of a generic "invalid ... value" error.
        raise argparse.ArgumentTypeError(
            'Boolean argument needs to be true or false. '
            'Instead, it is %s.' % v)

    parser = argparse.ArgumentParser(description="Implementation of SwAV")
    # Make the string 'bool' usable as ``type='bool'`` in add_argument().
    parser.register('type', 'bool', str2bool)

    def add_bool(name, default, help_text):
        """Register a boolean option that takes an explicit true/false value."""
        parser.add_argument(name, type='bool', default=default, help=help_text)

    #########################
    #### data parameters ####
    #########################
    parser.add_argument(
        "--ds_name", type=str, default="kinetics",
        choices=['kinetics', 'vggsound', 'kinetics_sound', 'ave', 'ucf101', 'hmdb51'],
        help="name of dataset")
    parser.add_argument("--root_dir", type=str, default="/path/to/dataset",
                        help="root dir of dataset")
    parser.add_argument("--data_path", type=str, default="datasets/data",
                        help="path to store dataset pkl files")
    parser.add_argument("--num_data_samples", type=int, default=None,
                        help="number of dataset samples")
    parser.add_argument("--num_frames", type=int, default=30,
                        help="number of frames to sample per clip")
    parser.add_argument("--target_fps", type=int, default=30,
                        help="video fps")
    parser.add_argument("--sample_rate", type=int, default=1,
                        help="rate to sample frames")
    parser.add_argument("--num_train_clips", type=int, default=1,
                        help="number of clips to sample per video")
    parser.add_argument("--train_crop_size", type=int, default=112,
                        help="train crop size")
    parser.add_argument("--test_crop_size", type=int, default=112,
                        help="test crop size")
    add_bool('--colorjitter', 'False', 'use color jitter')
    add_bool('--use_grayscale', 'False', 'use grayscale augmentation')
    add_bool('--use_gaussian', 'False', 'use gaussian augmentation')
    parser.add_argument("--num_sec_aud", type=int, default=1,
                        help="number of seconds of audio")
    parser.add_argument("--aud_sample_rate", type=int, default=48000,
                        help="audio sample rate")
    parser.add_argument("--aud_spec_type", type=int, default=2,
                        help="audio spec type")
    add_bool('--use_volume_jittering', 'False', 'use volume jittering')
    add_bool('--use_audio_temp_jittering', 'False', 'use audio temporal jittering')
    add_bool('--z_normalize', 'False', 'z-normalize the audio')
    add_bool('--dual_data', 'False', 'sample two clips per video')

    #########################
    #### optim parameters ###
    #########################
    parser.add_argument("--epochs", default=100, type=int,
                        help="number of total epochs to run")
    parser.add_argument("--batch_size", default=16, type=int,
                        help="batch size per gpu, i.e. how many unique instances per gpu")
    parser.add_argument("--base_lr", default=4.8, type=float,
                        help="base learning rate")
    parser.add_argument("--wd", default=1e-6, type=float,
                        help="weight decay")
    parser.add_argument("--warmup_epochs", default=10, type=int,
                        help="number of warmup epochs")
    add_bool("--use_warmup_scheduler", 'True', "use warmup scheduler")
    add_bool("--use_lr_scheduler", 'False', "use cosine LR scheduler")

    #########################
    #### SK parameters ###
    #########################
    parser.add_argument('--schedulepower', default=1.5, type=float,
                        help='SK schedule power compared to linear (default: 1.5)')
    parser.add_argument('--nopts', default=100, type=int,
                        help='number of pseudo-opts (default: 100)')
    # Help text previously advertised a default of 25; the real default is 20.
    parser.add_argument('--lamb', default=20, type=int,
                        help='for pseudoopt: lambda (default: 20)')
    parser.add_argument('--dist', default=None, type=int,
                        help='use for distribution')
    add_bool('--diff_dist_every', 'False',
             'use a different Gaussian at every SK-iter?')
    add_bool('--diff_dist_per_head', 'True',
             'use a different Gaussian for every head?')

    #########################
    #### Selavi parameters ###
    #########################
    # Help text previously advertised a default of 100; the real default is 1.
    parser.add_argument('--ind_groups', default=1, type=int,
                        help='number of independent groups (default: 1)')
    parser.add_argument('--gauss_sd', default=0.1, type=float,
                        help='sd')
    add_bool('--match', 'True', 'match distributions at beginning of training')
    parser.add_argument('--distribution', default='default', type=str,
                        help='distribution of SK-clustering',
                        choices=['gauss', 'default', 'zipf'])

    #########################
    #### dist parameters ###
    #########################
    parser.add_argument(
        "--dist_url", default="env://", type=str,
        help="url used to set up distributed training; "
             "see https://pytorch.org/docs/stable/distributed.html")
    parser.add_argument(
        "--world_size", default=-1, type=int,
        help="number of processes: it is set automatically and "
             "should not be passed as argument")
    parser.add_argument(
        "--rank", default=0, type=int,
        help="rank of this process: it is set automatically and "
             "should not be passed as argument")
    parser.add_argument("--local_rank", default=0, type=int,
                        help="this argument is not used and should be ignored")
    parser.add_argument("--bash", action='store_true', help="slurm bash mode")
    # NOTE(review): the original help here was a copy of the --bash help
    # ("slrum bash mode"). The wording below is inferred from the option
    # name only -- confirm against the code that consumes ``args.resume``.
    add_bool("--resume", 'False', "resume a previous run")

    #########################
    #### model parameters ###
    #########################
    parser.add_argument("--vid_base_arch", default="r2plus1d_18", type=str,
                        help="video architecture", choices=['r2plus1d_18'])
    parser.add_argument("--aud_base_arch", default="resnet9", type=str,
                        help="audio architecture", choices=['resnet9', 'resnet18'])
    add_bool('--use_mlp', 'True', 'use MLP head')
    parser.add_argument("--mlp_dim", default=256, type=int,
                        help="final layer dimension in projection head")
    parser.add_argument("--headcount", default=1, type=int,
                        help="number of heads")

    #########################
    #### other parameters ###
    #########################
    parser.add_argument("--workers", default=10, type=int,
                        help="number of data loading workers")
    parser.add_argument("--checkpoint_freq", type=int, default=5,
                        help="Save the model periodically")
    add_bool("--use_fp16", 'False',
             "whether to train with mixed precision or not")
    parser.add_argument("--sync_bn", type=str, default="pytorch",
                        help="synchronize bn")
    parser.add_argument("--dump_path", type=str, default=".",
                        help="experiment dump path for checkpoints and log")
    parser.add_argument("--seed", type=int, default=31, help="seed")
    return parser