# Attributed to https://github.com/vincentherrmann/pytorch-wavenet/blob/master/wavenet_modules.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch.autograd import Variable, Function
import numpy as np


def dilate(x, dilation, init_dilation=1, pad_start=True):
    """
    :param x: Tensor of size (N, C, L), where N is the input dilation, C is the number of channels, and L is the input length
    :param dilation: Target dilation. Will be the size of the first dimension of the output tensor.
    :param init_dilation: Current dilation of the input tensor x.
    :param pad_start: If the input length is not compatible with the specified dilation, zero padding is used. This parameter determines whether the zeros are added at the start or at the end.
    :return: The dilated tensor of size (dilation, C, L*N / dilation). The output might be zero padded at the start.
    """
    [n, c, l] = x.size()
    dilation_factor = dilation / init_dilation
    if dilation_factor == 1:
        return x

    # zero padding for reshaping
    new_l = int(np.ceil(l / dilation_factor) * dilation_factor)
    if new_l != l:
        l = new_l
        x = constant_pad_1d(x, new_l, dimension=2, pad_start=pad_start)

    l = math.ceil(l * init_dilation / dilation)
    n = math.ceil(n * dilation / init_dilation)

    # reshape according to dilation
    x = x.permute(1, 2, 0).contiguous()  # (n, c, l) -> (c, l, n)
    x = x.view(c, l, n)
    x = x.permute(2, 0, 1).contiguous()  # (c, l, n) -> (n, c, l)

    return x
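
# Illustrative usage sketch (not part of the original module); the shapes are example
# assumptions. dilate() trades batch size against length without changing the data:
#
#   x = torch.zeros(1, 16, 32)                   # (N=1, C=16, L=32)
#   y = dilate(x, dilation=4)                    # -> (4, 16, 8)
#   z = dilate(y, dilation=1, init_dilation=4)   # folds back to (1, 16, 32)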


class DilatedQueue:
    """Circular buffer of past samples, used to evaluate a dilated convolution
    one time step at a time during generation."""

    def __init__(self, max_length, data=None, dilation=1, num_deq=1, num_channels=1, dtype=torch.FloatTensor):
        self.in_pos = 0
        self.out_pos = 0
        self.num_deq = num_deq
        self.num_channels = num_channels
        self.dilation = dilation
        self.max_length = max_length
        self.data = data
        self.dtype = dtype
        if data is None:
            self.data = Variable(dtype(num_channels, max_length).zero_())

    def enqueue(self, input):
        # write the newest sample at in_pos and advance the write pointer
        self.data[:, self.in_pos] = input
        self.in_pos = (self.in_pos + 1) % self.max_length

    def dequeue(self, num_deq=1, dilation=1):
        # The buffer is circular, e.g. |6|7|8|1|2|3|4|5|: read num_deq samples
        # spaced `dilation` apart, ending at out_pos, wrapping around if needed.
        start = self.out_pos - ((num_deq - 1) * dilation)
        if start < 0:
            t1 = self.data[:, start::dilation]
            t2 = self.data[:, self.out_pos % dilation:self.out_pos + 1:dilation]
            t = torch.cat((t1, t2), 1)
        else:
            t = self.data[:, start:self.out_pos + 1:dilation]

        self.out_pos = (self.out_pos + 1) % self.max_length
        return t

    def reset(self):
        self.data = Variable(self.dtype(self.num_channels, self.max_length).zero_())
        self.in_pos = 0
        self.out_pos = 0
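
# Illustrative usage sketch (not part of the original module); sizes are example
# assumptions:
#
#   q = DilatedQueue(max_length=8, num_channels=2)
#   q.enqueue(torch.ones(2))                # push the newest time step
#   t = q.dequeue(num_deq=2, dilation=4)    # (2, 2): two samples spaced 4 steps apart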


class ConstantPad1d(nn.Module):
    def __init__(self, target_size, dimension=0, value=0, pad_start=False):
        super(ConstantPad1d, self).__init__()
        self.target_size = target_size
        self.dimension = dimension
        self.value = value
        self.pad_start = pad_start

    def forward(self, input):
        self.num_pad = self.target_size - input.size(self.dimension)
        assert self.num_pad >= 0, 'target size has to be greater than or equal to the input size'

        self.input_size = input.size()
        size = list(input.size())
        size[self.dimension] = self.target_size

        # allocate the padded output, then copy the input into the non-padded region
        output = input.new(*tuple(size)).fill_(self.value)
        c_output = output

        # crop output
        if self.pad_start:
            c_output = c_output.narrow(self.dimension, self.num_pad, c_output.size(self.dimension) - self.num_pad)
        else:
            c_output = c_output.narrow(self.dimension, 0, c_output.size(self.dimension) - self.num_pad)

        c_output.copy_(input)
        return output

    def backward(self, grad_output):
        grad_input = grad_output.new(*self.input_size).zero_()
        cg_output = grad_output

        # crop grad_output
        if self.pad_start:
            cg_output = cg_output.narrow(self.dimension, self.num_pad, cg_output.size(self.dimension) - self.num_pad)
        else:
            cg_output = cg_output.narrow(self.dimension, 0, cg_output.size(self.dimension) - self.num_pad)

        grad_input.copy_(cg_output)
        return grad_input


def constant_pad_1d(input,
                    target_size,
                    dimension=0,
                    value=0,
                    pad_start=False):
    return ConstantPad1d(target_size, dimension, value, pad_start)(input)
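
# Minimal smoke test (not part of the original module); run the file directly.
# The tensor sizes below are arbitrary example values.
if __name__ == '__main__':
    example = torch.arange(8).view(1, 1, 8).float()
    padded = constant_pad_1d(example, target_size=10, dimension=2, pad_start=True)
    print(padded.size())  # torch.Size([1, 1, 10]); two zeros prepended along dim 2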