-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathcompute_cmvn_stats.py
executable file
·141 lines (119 loc) · 4.91 KB
/
compute_cmvn_stats.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#!/usr/bin/env python3
# encoding: utf-8
import sys
import argparse
import json
import codecs
import yaml
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch.utils.data import Dataset, DataLoader
class CollateFunc(object):
    """Collate function for AudioDataset.

    Rather than stacking samples, each call reduces a batch of
    ``(key, wav_entry)`` pairs to partial fbank statistics (frame count,
    per-bin sum and per-bin sum of squares), so the caller only has to add
    the per-batch results together.
    """

    def __init__(self, feat_dim, resample_rate):
        """Store the feature configuration.

        Args:
            feat_dim: number of mel bins of the fbank features.
            resample_rate: target sample rate; 0 disables resampling.
        """
        self.feat_dim = feat_dim
        self.resample_rate = resample_rate
        # Resample transforms keyed by source sample rate, so a transform is
        # constructed once per source rate instead of once per utterance.
        self._resamplers = {}

    def _resample(self, waveform, orig_freq):
        """Resample ``waveform`` from ``orig_freq`` to ``self.resample_rate``."""
        resampler = self._resamplers.get(orig_freq)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(
                orig_freq=orig_freq, new_freq=self.resample_rate)
            self._resamplers[orig_freq] = resampler
        return resampler(waveform)

    def __call__(self, batch):
        """Accumulate fbank statistics over one batch.

        Args:
            batch: iterable of ``(key, wav_entry)`` pairs, where
                ``wav_entry`` is either ``"wav_path"`` or
                ``"wav_path,start,end"`` (start/end in seconds).

        Returns:
            Tuple ``(number, mean_stat, var_stat)``: total number of
            feature frames, per-bin sum of the features and per-bin sum of
            the squared features.

        Raises:
            ValueError: if a ``wav_entry`` has neither 1 nor 3
                comma-separated fields.
        """
        mean_stat = torch.zeros(self.feat_dim)
        var_stat = torch.zeros(self.feat_dim)
        number = 0
        for item in batch:
            value = item[1].strip().split(",")
            # len(value) == 3 means segmented wav.scp,
            # len(value) == 1 means original wav.scp
            if len(value) not in (1, 3):
                # Explicit exception: an assert would vanish under ``-O``.
                raise ValueError(
                    'expected "wav_path" or "wav_path,start,end", '
                    'got {!r}'.format(item[1]))
            wav_path = value[0]
            if len(value) == 3:
                # The file's own sample rate is only needed here, to turn
                # the start/end times into sample offsets.
                native_rate = torchaudio.info(wav_path).sample_rate
                start_frame = int(float(value[1]) * native_rate)
                end_frame = int(float(value[2]) * native_rate)
                waveform, sample_rate = torchaudio.load(
                    filepath=wav_path,
                    num_frames=end_frame - start_frame,
                    frame_offset=start_frame)
            else:
                waveform, sample_rate = torchaudio.load(wav_path)

            # Scale samples by 2**15.
            # NOTE(review): presumably maps normalized float audio to the
            # 16-bit integer range expected by Kaldi-style fbank - confirm.
            waveform = waveform * (1 << 15)

            feat_rate = sample_rate
            if self.resample_rate != 0 and self.resample_rate != sample_rate:
                waveform = self._resample(waveform, sample_rate)
                feat_rate = self.resample_rate

            mat = kaldi.fbank(waveform,
                              num_mel_bins=self.feat_dim,
                              dither=0.0,
                              energy_floor=0.0,
                              sample_frequency=feat_rate)
            mean_stat += torch.sum(mat, dim=0)
            var_stat += torch.sum(torch.square(mat), dim=0)
            number += mat.shape[0]
        return number, mean_stat, var_stat
class AudioDataset(Dataset):
    """In-memory dataset of ``(key, wav_entry)`` pairs read from a scp file.

    Every non-blank line must contain at least two whitespace-separated
    fields: the first is stored as the key, the second as the wav entry.
    Any further fields on the line are ignored.
    """

    def __init__(self, data_file):
        """Read ``data_file`` (UTF-8 text) and keep all entries in memory.

        Raises:
            ValueError: if a non-blank line has fewer than two fields.
        """
        self.items = []
        with open(data_file, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                arr = line.split()
                if not arr:
                    # Tolerate blank lines, e.g. a trailing newline at EOF.
                    continue
                if len(arr) < 2:
                    raise ValueError(
                        '{}:{}: expected "<key> <wav_entry>", got {!r}'.format(
                            data_file, line_no, line.strip()))
                self.items.append((arr[0], arr[1]))

    def __len__(self):
        """Return the number of entries read from the scp file."""
        return len(self.items)

    def __getitem__(self, idx):
        """Return the ``(key, wav_entry)`` tuple at position ``idx``."""
        return self.items[idx]
if __name__ == '__main__':
    # Script entry point: accumulate global CMVN statistics (frame count,
    # per-bin feature sum and per-bin squared-feature sum) over every entry
    # of a wav scp file and write them to a JSON file.
    parser = argparse.ArgumentParser(description='extract CMVN stats')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='num of subprocess workers for processing')
    parser.add_argument('--train_config',
                        default='',
                        help='training yaml conf')
    parser.add_argument('--in_scp', default=None, help='wav scp file')
    parser.add_argument('--out_cmvn',
                        default='global_cmvn',
                        help='global cmvn file')

    doc = "Print log after every log_interval audios are processed."
    parser.add_argument("--log_interval", type=int, default=1000, help=doc)
    args = parser.parse_args()

    # Fail with a usage message instead of an obscure traceback from
    # open('') / open(None) further down.
    if not args.train_config:
        parser.error('--train_config is required')
    if args.in_scp is None:
        parser.error('--in_scp is required')

    with open(args.train_config, 'r') as fin:
        # NOTE(review): yaml.load with FullLoader can construct more than
        # plain data types; only use with trusted config files, or switch to
        # yaml.safe_load if the configs never rely on YAML tags.
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    feat_dim = configs['dataset_conf']['fbank_conf']['num_mel_bins']
    resample_rate = 0  # 0 tells CollateFunc not to resample
    if 'resample_conf' in configs['dataset_conf']:
        resample_rate = configs['dataset_conf']['resample_conf'][
            'resample_rate']
        print(f'using resample and new sample rate is {resample_rate}')

    collate_func = CollateFunc(feat_dim, resample_rate)
    dataset = AudioDataset(args.in_scp)
    batch_size = 20
    # NOTE(review): shuffling does not change the accumulated sums apart
    # from floating-point rounding; presumably kept to spread long and short
    # files across workers - confirm before removing.
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             sampler=None,
                             num_workers=args.num_workers,
                             collate_fn=collate_func)

    total_wavs = len(dataset)
    with torch.no_grad():
        all_number = 0
        all_mean_stat = torch.zeros(feat_dim)
        all_var_stat = torch.zeros(feat_dim)
        wav_number = 0
        # Next wav count at which a progress line is due.
        next_log = args.log_interval
        for i, batch in enumerate(data_loader):
            number, mean_stat, var_stat = batch
            all_mean_stat += mean_stat
            all_var_stat += var_stat
            all_number += number
            # The last batch may be partial, so cap at the dataset size
            # instead of blindly adding batch_size.
            wav_number = min((i + 1) * batch_size, total_wavs)

            # Log whenever a multiple of log_interval has been crossed, so
            # intervals that are not multiples of batch_size still log.
            if args.log_interval > 0 and wav_number >= next_log:
                print(f'processed {wav_number} wavs, {all_number} frames',
                      file=sys.stderr,
                      flush=True)
                next_log = ((wav_number // args.log_interval) + 1) \
                    * args.log_interval

    cmvn_info = {
        'mean_stat': all_mean_stat.tolist(),
        'var_stat': all_var_stat.tolist(),
        'frame_num': all_number
    }

    with open(args.out_cmvn, 'w') as fout:
        fout.write(json.dumps(cmvn_info))