cnn8rnn_w2vmean_random.yaml
experiment_path: experiments/audiocaps/phrase_level/cnn8rnn_w2vmean_dp_ls/sampling_random/seed_1
seed: 1
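# Data: the train and val splits share one pipeline. AudioSamplePhrasesDataset
# pairs each audio clip with 32 phrases, drawing negatives at random
# (neg_samp_stratg: random) rather than keeping them fixed (fix_neg: False).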
data:
  train:
    dataset:
      type: datasets.multi_phrase_dataset.AudioSamplePhrasesDataset
      args:
        label: data/audiocaps/train.json
        waveform: data/audiocaps/waveform.csv
        phrase_num: 32
        neg_samp_stratg: random
        fix_neg: False
    collate_fn:
      type: datasets.collate_function.TextCollate
      args:
        pad_keys: [waveform]
        text_key: phrases
      tokenizer:
        type: datasets.text_tokenizer.DictTokenizer
        args:
          vocabulary: data/audiocaps/vocab.pkl
    dataloader_args:
      batch_size: 32
      num_workers: 4
  val:
    dataset:
      type: datasets.multi_phrase_dataset.AudioSamplePhrasesDataset
      args:
        label: data/audiocaps/val.json
        waveform: data/audiocaps/waveform.csv
        phrase_num: 32
        neg_samp_stratg: random
        fix_neg: False
    collate_fn:
      type: datasets.collate_function.TextCollate
      args:
        pad_keys: [waveform]
        text_key: phrases
      tokenizer:
        type: datasets.text_tokenizer.DictTokenizer
        args:
          vocabulary: data/audiocaps/vocab.pkl
    dataloader_args:
      batch_size: 32
      num_workers: 4
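# Model: a bi-encoder that matches a CNN8-RNN audio encoder (initialized from
# an AudioSet-strong checkpoint) against aggregated word embeddings, scored
# by dot product with linear-softmax pooling.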
model:
  audio_encoder:
    type: models.audio_encoder.Cnn8_Rnn
    args:
      sample_rate: 32000
      freeze_cnn: False
      freeze_bn: False
      pretrained: experiments/pretrained_audio_encoder/audioset_strong_cnn8rnn.pth
  text_encoder:
    type: models.text_encoder.EmbeddingAgg
    args:
      vocab_size: 5221
      embed_dim: 512
  match_fn:
    type: models.match.DotProduct
    args: {}
  type: models.audio_text_model.MultiTextBiEncoder
  args:
    shared_dim: 512
    text_forward_keys: [text, text_len]
    add_proj: False
    pooling: linear_softmax
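# Optimization: clip-level BCE loss, Adam at lr 0.001, and the learning rate
# is reduced when the monitored loss plateaus (patience: 3).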
loss:
  type: losses.ClipBceLoss
  args: {}
optimizer:
  type: torch.optim.Adam
  args:
    lr: 0.001
lr_scheduler:
  type: torch.optim.lr_scheduler.ReduceLROnPlateau
  args:
    mode: min
    patience: 3
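# Evaluation: sweep 50 score thresholds; inference uses window_size 1.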
eval_config:
  n_thresholds: 50
  inference_args:
    window_size: 1
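# Trainer: up to 100 epochs with early stopping after 10 epochs without
# improvement in the monitored validation loss; gradients are clipped to
# norm 1.0 and checkpoints are saved every 10 epochs.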
trainer:
  metric_monitor:
    mode: min
    name: loss
  epochs: 100
  early_stop: 10
  lr_update_interval: epoch
  save_interval: 10
  include_optim_in_ckpt: False
  max_grad_norm: 1.0
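
Every component above follows the same `type`/`args` convention: `type` is a dotted import path and `args` holds the constructor's keyword arguments. The repo's actual loading code is not shown on this page, but a minimal sketch of how such a config could be resolved looks like this (the `instantiate` helper is hypothetical, and a real loader would also pass extra arguments such as model parameters to the optimizer):

```python
# Minimal sketch of resolving "type"/"args" entries from this config.
# `instantiate` is a hypothetical helper, not the repository's real loader.
import importlib

import yaml


def instantiate(spec, **extra):
    """Import spec["type"] as a dotted path and call it with spec["args"]."""
    module_name, _, attr = spec["type"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), attr)
    return cls(**{**spec.get("args", {}), **extra})


with open("cnn8rnn_w2vmean_random.yaml") as f:
    config = yaml.safe_load(f)

# Builds datasets.multi_phrase_dataset.AudioSamplePhrasesDataset(
#     label="data/audiocaps/train.json", ...) from the "data.train.dataset" entry.
train_dataset = instantiate(config["data"]["train"]["dataset"])
```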