LSTM.py
'''
Adapted from https://github.com/keitakurita/Better_LSTM_PyTorch/blob/master/better_lstm/model.py

Add to net.py:
    self.lstm = LSTM(1 + params.cov_dim + params.embedding_dim, params.lstm_hidden_dim,
                     params.lstm_layers, bias=True, batch_first=False,
                     dropout=params.lstm_dropout)
'''
import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence
from typing import Optional


class VariationalDropout(nn.Module):
    """
    Applies the same dropout mask across the temporal dimension.
    See https://arxiv.org/abs/1512.05287 for more details.

    Note that, unlike the paper above, this is not applied to the recurrent
    activations inside the LSTM; it is applied to the inputs and outputs of
    the recurrent layer instead.
    """
    def __init__(self, dropout: float, batch_first: Optional[bool] = False):
        super().__init__()
        self.dropout = dropout
        self.batch_first = batch_first

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.dropout <= 0.:
            return x

        is_packed = isinstance(x, PackedSequence)
        if is_packed:
            x, batch_sizes = x
            max_batch_size = int(batch_sizes[0])
        else:
            batch_sizes = None
            # The batch dimension is dim 0 when batch_first, otherwise dim 1.
            max_batch_size = x.size(0) if self.batch_first else x.size(1)

        # Sample one mask per sequence and reuse it at every time step.
        if self.batch_first:
            m = x.new_empty(max_batch_size, 1, x.size(2), requires_grad=False).bernoulli_(1 - self.dropout)
        else:
            m = x.new_empty(1, max_batch_size, x.size(2), requires_grad=False).bernoulli_(1 - self.dropout)
        x = x.masked_fill(m == 0, 0) / (1 - self.dropout)

        if is_packed:
            return PackedSequence(x, batch_sizes)
        return x


class LSTM(nn.LSTM):
    def __init__(self, *args, dropouti: float = 0.,
                 dropoutw: float = 0., dropouto: float = 0.,
                 batch_first=True, unit_forget_bias=True, **kwargs):
        super().__init__(*args, **kwargs, batch_first=batch_first)
        self.unit_forget_bias = unit_forget_bias
        self.dropoutw = dropoutw
        self.input_drop = VariationalDropout(dropouti, batch_first=batch_first)
        self.output_drop = VariationalDropout(dropouto, batch_first=batch_first)
        self._init_weights()

    def _init_weights(self):
        """
        Use orthogonal init for the recurrent weights and Xavier uniform for the
        input weights. Biases are zero except for the forget gate, which is set to 1.
        """
        for name, param in self.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param.data)
            elif "weight_ih" in name:
                nn.init.xavier_uniform_(param.data)
            elif "bias" in name and self.unit_forget_bias:
                nn.init.zeros_(param.data)
                # nn.LSTM orders gates as (input, forget, cell, output), so the
                # forget-gate slice of the bias is [hidden_size:2 * hidden_size].
                param.data[self.hidden_size:2 * self.hidden_size] = 1

    def _drop_weights(self):
        # DropConnect: apply dropout to the hidden-to-hidden weight matrices
        # on every forward pass while training.
        for name, param in self.named_parameters():
            if "weight_hh" in name:
                getattr(self, name).data = \
                    torch.nn.functional.dropout(param.data, p=self.dropoutw,
                                                training=self.training).contiguous()

    def forward(self, input, hx=None):
        self._drop_weights()
        input = self.input_drop(input)
        seq, state = super().forward(input, hx=hx)
        return self.output_drop(seq), state
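

# A minimal usage sketch, not part of the original file: the sizes below and the
# standalone call are illustrative assumptions, not values taken from net.py/params.
if __name__ == "__main__":
    torch.manual_seed(0)

    # Hypothetical dimensions for a quick smoke test.
    input_size, hidden_size, num_layers = 16, 32, 2
    seq_len, batch_size = 10, 4

    lstm = LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False,
                dropouti=0.2, dropoutw=0.2, dropouto=0.2)

    # Time-major input (seq_len, batch, features) because batch_first=False.
    x = torch.randn(seq_len, batch_size, input_size)
    out, (h_n, c_n) = lstm(x)
    print(out.shape)   # torch.Size([10, 4, 32])
    print(h_n.shape)   # torch.Size([2, 4, 32]) -- one hidden state per layer

    # VariationalDropout zeroes the same feature positions at every time step:
    vd = VariationalDropout(dropout=0.5, batch_first=False)
    y = vd(torch.ones(seq_len, batch_size, 8))
    assert torch.equal(y[0] == 0, y[-1] == 0)  # identical mask at t=0 and t=-1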