-
Notifications
You must be signed in to change notification settings - Fork 673
/
Copy pathcompliance_kaldi_test.py
76 lines (61 loc) · 3.01 KB
/
compliance_kaldi_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import torchaudio.compliance.kaldi as kaldi
from torchaudio_unittest import common_utils
def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
    """Fill row ``f`` of ``window`` with frame ``f`` of ``wave``, Kaldi-style.

    A slow reference port of ``ExtractWindow`` / ``FirstSampleOfFrame`` from
    Kaldi's ``feature-window.cc``. Only the sample-copying part is ported here;
    nothing else is done to the extracted frame.

    Args:
        window: tensor indexed as ``window[f, s]``; row ``f`` is overwritten in
            place with ``frame_length`` samples.
        wave: tensor of samples indexed along dimension 0 (the tests pass a
            1-D tensor).
        f: index of the frame to extract.
        frame_length: number of samples per frame (window size).
        frame_shift: number of samples between the starts of consecutive frames.
        snip_edges: if true, frame ``f`` starts at ``f * frame_shift`` and must
            lie entirely inside ``wave``. If false, the frame is centred on
            ``f * frame_shift + frame_shift // 2``, and samples falling outside
            ``wave`` are taken by reflecting about its ends.

    Raises:
        AssertionError: if ``wave`` is empty, or if ``snip_edges`` is set and
            the frame does not fit inside ``wave``.
    """

    def first_sample_of_frame(frame, window_size, window_shift, snip_edges):
        # Port of FirstSampleOfFrame: index of the first sample of ``frame``.
        if snip_edges:
            return frame * window_shift
        midpoint_of_frame = frame * window_shift + window_shift // 2
        return midpoint_of_frame - window_size // 2

    # Kaldi's ExtractWindow also asserts a non-empty wave. Without this check
    # the reflection loop below never terminates for an empty wave
    # (index 0 -> -1 -> 0 -> ...).
    assert wave.size(0) != 0, "wave must not be empty"

    sample_offset = 0  # kept from the Kaldi signature; always 0 in this port
    wave_dim = wave.size(0)
    num_samples = sample_offset + wave_dim
    start_sample = first_sample_of_frame(f, frame_length, frame_shift, snip_edges)
    end_sample = start_sample + frame_length

    if snip_edges:
        assert start_sample >= sample_offset and end_sample <= num_samples
    else:
        assert sample_offset == 0 or start_sample >= sample_offset

    wave_start = start_sample - sample_offset
    wave_end = wave_start + frame_length
    if wave_start >= 0 and wave_end <= wave_dim:
        # Normal case: the whole frame lies inside wave.
        window[f, :] = wave[wave_start:wave_end]
        return

    # Edge case: reflect out-of-range indices about the ends of wave, repeating
    # the edge sample (-1 -> 0, -2 -> 1; dim -> dim - 1, dim + 1 -> dim - 2).
    # The while loop allows repeated reflections when the frame is much longer
    # than the wave.
    for s in range(frame_length):
        s_in_wave = s + wave_start
        while s_in_wave < 0 or s_in_wave >= wave_dim:
            if s_in_wave < 0:
                s_in_wave = -s_in_wave - 1
            else:
                s_in_wave = 2 * wave_dim - 1 - s_in_wave
        window[f, s] = wave[s_in_wave]
class Test_Kaldi(common_utils.TempDirMixin, common_utils.TorchaudioTestCase):
    """Tests for ``torchaudio.compliance.kaldi`` against Python ports of Kaldi code."""

    def _test_get_strided_helper(self, num_samples, window_size, window_shift, snip_edges):
        """Compare ``kaldi._get_strided`` with frames built one at a time by ``extract_window``.

        The input waveform is ``0, 1, ..., num_samples - 1``, so every output
        value identifies the input sample it was copied from.
        """
        waveform = torch.arange(num_samples).float()
        output = kaldi._get_strided(waveform, window_size, window_shift, snip_edges)

        # Expected number of frames, from NumFrames in feature-window.cc.
        if snip_edges:
            if num_samples < window_size:
                num_frames = 0
            else:
                num_frames = 1 + (num_samples - window_size) // window_shift
        else:
            num_frames = (num_samples + (window_shift // 2)) // window_shift

        # assertEqual (rather than assertTrue(a == b)) reports the mismatching values.
        self.assertEqual(output.dim(), 2)
        self.assertEqual(tuple(output.shape), (num_frames, window_size))

        expected = torch.empty((num_frames, window_size))
        for frame in range(num_frames):
            extract_window(expected, waveform, frame, window_size, window_shift, snip_edges)
        self.assertEqual(expected, output)

    def test_get_strided(self):
        """Exhaustively check ``_get_strided`` on small inputs.

        Covers every combination with 0 < num_samples < 20,
        0 < window_size <= num_samples, 0 < window_shift <= 2 * num_samples,
        and both values of snip_edges.
        """
        for num_samples in range(1, 20):
            for window_size in range(1, num_samples + 1):
                for window_shift in range(1, 2 * num_samples + 1):
                    for snip_edges in (False, True):
                        # subTest names the failing combination out of ~10k cases.
                        with self.subTest(
                            num_samples=num_samples,
                            window_size=window_size,
                            window_shift=window_shift,
                            snip_edges=snip_edges,
                        ):
                            self._test_get_strided_helper(
                                num_samples, window_size, window_shift, snip_edges
                            )

    def test_mfcc_empty(self):
        """Passing an empty tensor to ``kaldi.mfcc`` should raise ``AssertionError``."""
        with self.assertRaises(AssertionError):
            kaldi.mfcc(torch.empty(0))