"""
This module contains a custom PyTorch LSTM class and a wrapper for the PyTorch
LSTM layer that deals with packing and unpacking padded inputs.
Author: Steve Bischoff
"""
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
torch.manual_seed(0)


class pack_pad_lstm_wrapper(nn.Module):
    """Wraps the PyTorch LSTM layer, packing padded inputs before the layer and
    re-padding its output."""

    def __init__(self, vocab, embedding_size, hidden_size, batch_size=1):
        super().__init__()
        self.padding_value = vocab['<pad>']
        self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True)

    def forward(self, embeds, lengths):
        # pack padded embeddings using the lengths list before the LSTM layer
        packed_embeds = pack_padded_sequence(embeds, lengths, batch_first=True, enforce_sorted=False)
        packed_hidden, _ = self.lstm(packed_embeds)
        # pad the packed hidden states back out using the vocab padding value
        hidden, _ = pad_packed_sequence(packed_hidden, batch_first=True, padding_value=self.padding_value)
        return hidden
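
# Illustrative example (not part of the original module): the wrapper takes a
# padded batch of embeddings plus the true sequence lengths and returns the
# per-timestep hidden states, shaped (batch, max_len, hidden_size).
#
#     vocab = {'<pad>': 0, 'the': 1, 'movie': 2}            # hypothetical vocab
#     wrapper = pack_pad_lstm_wrapper(vocab, embedding_size=8, hidden_size=16)
#     embeds = torch.randn(2, 5, 8)                          # (batch, max_len, embedding)
#     hidden = wrapper(embeds, lengths=[5, 3])               # -> shape (2, 5, 16)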


class LSTM(nn.Module):
    """LSTM classifier: embedding -> packed LSTM -> dropout -> linear -> log-softmax."""

    def __init__(self, vocab, embedding_size, hidden_size, output_size, batch_size=1):
        super().__init__()
        self.embedding = nn.Embedding(len(vocab), embedding_size)
        self.pack_pad_lstm = pack_pad_lstm_wrapper(vocab, embedding_size, hidden_size)
        self.dropout = nn.Dropout(0.2)
        self.h_2_o = nn.Linear(hidden_size, output_size)
        # activations: log_softmax for training, softmax for getting probabilities
        self.log_softmax = nn.LogSoftmax(dim=1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, idx, lengths):
        embeds = self.embedding(idx)
        hidden = self.pack_pad_lstm(embeds, lengths)
        hidden = self.dropout(hidden)
        # keep only the last unpadded hidden state of each sequence using lengths
        hidden = torch.stack([h[length-1] for h, length in zip(hidden, lengths)])
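        # Optional, equivalent vectorized form of the line above (illustrative
        # sketch only, not the original code):
        #     last = (torch.as_tensor(lengths) - 1).view(-1, 1, 1).expand(-1, 1, hidden.size(2))
        #     hidden = hidden.gather(1, last).squeeze(1)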
        output = self.h_2_o(hidden)
        output = self.log_softmax(output)
        return output

    @torch.no_grad()
    def evaluate_batch(self, idx, lengths, y, loss_criterion):
        """
        Calculate loss and accuracy counts (# correct, # total) for a batch of inputs.
        Assumes self is in evaluation mode (self.eval()).
        Params:
            idx: tensor of shape (batch_size, max_length)
            lengths: list of integers
            y: tensor of shape (batch_size)
            loss_criterion: torch loss function
        Returns:
            loss: float
            (n_correct, n_total): tuple of ints
        """
        outputs = self(idx, lengths)
        predictions = outputs.topk(1)[1]
        # calculate loss
        loss = loss_criterion(outputs, y).item()
        # count correct predictions
        n_correct = torch.sum(predictions == y.view(-1, 1)).item()
        return loss, (n_correct, len(idx))
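
    # Example (illustrative; nn.NLLLoss is assumed here because forward returns
    # log-probabilities rather than raw logits):
    #
    #     criterion = nn.NLLLoss()
    #     model.eval()
    #     batch_loss, (n_correct, n_total) = model.evaluate_batch(idx, lengths, y, criterion)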

    @torch.no_grad()
    def evaluate(self, dataloader, loss_criterion):
        """
        Calculate loss and accuracy for all inputs in a dataloader.
        Params:
            dataloader: DataLoader object for IMDBDataset
            loss_criterion: torch loss function
        Returns:
            loss: float
            acc: float
        """
        # initialize running totals
        losses = []
        n_correct, n_total = 0, 0
        # set evaluation mode
        self.eval()
        # iterate over batches
        for idx, lengths, y in dataloader:
            batch_loss, (batch_correct, batch_size) = self.evaluate_batch(idx, lengths, y, loss_criterion)
            losses.append(batch_loss)
            n_correct += batch_correct
            n_total += batch_size
        acc = n_correct / n_total
        return np.mean(losses), acc

    def train_batch(self, idx, lengths, y, loss_criterion, optimizer):
        """
        Train the model on a single batch.
        Params:
            idx: tensor of shape (batch_size, max_length)
            lengths: list of integers
            y: tensor of shape (batch_size)
            loss_criterion: torch loss function
            optimizer: torch optimizer
        Returns:
            loss: float
        """
        # reset gradients
        self.zero_grad()
        # get model output for the batch
        output = self(idx, lengths)
        # calculate loss
        loss = loss_criterion(output, y)
        # backpropagate and update weights
        loss.backward()
        optimizer.step()
        return loss.item()

    def train_epoch(self, dataloader, loss_criterion, optimizer):
        """
        Train the model for one epoch.
        Params:
            dataloader: DataLoader object for IMDBDataset
            loss_criterion: torch loss function
            optimizer: torch optimizer
        """
        # set training mode
        self.train()
        # iterate over batches
        for idx, lengths, y in dataloader:
            self.train_batch(idx, lengths, y, loss_criterion, optimizer)

    def fit(self,
            train_dataloader,
            loss_criterion,
            optimizer,
            epochs=1,
            track_train_stats=False,
            track_test_stats=False,
            test_dataloader=None,
            verbose=False
            ):
        """
        Train the model for a specified number of epochs, optionally tracking
        loss and accuracy on the train and/or test dataloaders after each epoch.
        Params:
            train_dataloader: DataLoader object for IMDBDataset
            loss_criterion: torch loss function
            optimizer: torch optimizer
            epochs: int >= 0, number of training epochs
            track_train_stats: bool, evaluate on train_dataloader after each epoch
            track_test_stats: bool, evaluate on test_dataloader after each epoch
            test_dataloader: DataLoader object for IMDBDataset (required if track_test_stats)
            verbose: bool, print progress and tracked stats
        """
        # sanity check
        assert not track_test_stats or test_dataloader is not None, 'Please set track_test_stats to False or specify a test dataloader.'
        # iterate over epochs
        for epoch in range(epochs):
            if verbose:
                print('Epoch ' + str(epoch + 1))
            # run one training epoch
            self.train_epoch(train_dataloader, loss_criterion, optimizer)
            # optionally track training stats
            if track_train_stats:
                train_loss, train_accuracy = self.evaluate(train_dataloader, loss_criterion)
                if verbose:
                    print('Train Loss: {0:.3f}. Train Accuracy: {1:.3f}'.format(train_loss, train_accuracy))
            # optionally track test stats
            if track_test_stats:
                test_loss, test_accuracy = self.evaluate(test_dataloader, loss_criterion)
                if verbose:
                    print('Test Loss: {0:.3f}. Test Accuracy: {1:.3f}'.format(test_loss, test_accuracy))
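

# Illustrative usage sketch (not part of the original module). The vocab and the
# toy "dataloader" below are hypothetical stand-ins for the IMDBDataset
# DataLoader referenced in the docstrings; any iterable of (idx, lengths, y)
# batches works, since train_epoch/evaluate simply iterate over it. nn.NLLLoss
# is assumed because forward returns log-probabilities.
if __name__ == '__main__':
    vocab = {'<pad>': 0, 'the': 1, 'film': 2, 'was': 3, 'great': 4, 'bad': 5}
    model = LSTM(vocab, embedding_size=8, hidden_size=16, output_size=2)
    criterion = nn.NLLLoss()  # pairs with the LogSoftmax output
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # two padded toy batches: token indices, true lengths, labels
    toy_batches = [
        (torch.tensor([[1, 2, 3, 4], [2, 3, 5, 0]]), [4, 3], torch.tensor([1, 0])),
        (torch.tensor([[1, 2, 3, 5], [2, 3, 4, 0]]), [4, 3], torch.tensor([0, 1])),
    ]
    model.fit(toy_batches, criterion, optimizer, epochs=2,
              track_train_stats=True, verbose=True)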