
Commit

update word language model
szha committed Jan 22, 2018
1 parent 1eb7e17 commit 4e92a88
Showing 3 changed files with 34 additions and 131 deletions.
66 changes: 0 additions & 66 deletions example/gluon/word_language_model/data.py

This file was deleted.

36 changes: 0 additions & 36 deletions example/gluon/word_language_model/get_wikitext2_data.sh

This file was deleted.

63 changes: 34 additions & 29 deletions example/gluon/word_language_model/train.py
@@ -23,9 +23,7 @@
import model
import data

parser = argparse.ArgumentParser(description='MXNet Autograd PennTreeBank RNN/LSTM Language Model')
parser.add_argument('--data', type=str, default='./data/wikitext-2/wiki.',
help='location of the data corpus')
parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Language Model on Wikitext-2.')
parser.add_argument('--model', type=str, default='lstm',
help='type of recurrent net (rnn_tanh, rnn_relu, lstm, gru)')
parser.add_argument('--emsize', type=int, default=200,
@@ -72,26 +70,37 @@
else:
context = mx.cpu(0)

corpus = data.Corpus(args.data)

def batchify(data, batch_size):
"""Reshape data into (num_example, batch_size)"""
nbatch = data.shape[0] // batch_size
data = data[:nbatch * batch_size]
data = data.reshape((batch_size, nbatch)).T
return data

train_data = batchify(corpus.train, args.batch_size).as_in_context(context)
val_data = batchify(corpus.valid, args.batch_size).as_in_context(context)
test_data = batchify(corpus.test, args.batch_size).as_in_context(context)
train_dataset = gluon.data.text.WikiText2('./data', 'train', seq_len=args.bptt)
indexer = train_dataset.indexer
val_dataset, test_dataset = [gluon.data.text.WikiText2('./data', segment,
indexer=indexer,
seq_len=args.bptt)
for segment in ['validation', 'test']]
train_data = gluon.data.DataLoader(train_dataset,
batch_size=args.batch_size,
sampler=gluon.data.IntervalSampler(len(train_dataset),
args.batch_size),
last_batch='discard')

val_data = gluon.data.DataLoader(val_dataset,
batch_size=args.batch_size,
sampler=gluon.data.IntervalSampler(len(val_dataset),
args.batch_size),
last_batch='discard')

test_data = gluon.data.DataLoader(test_dataset,
batch_size=args.batch_size,
sampler=gluon.data.IntervalSampler(len(test_dataset),
args.batch_size),
last_batch='discard')


###############################################################################
# Build the model
###############################################################################


ntokens = len(corpus.dictionary)
ntokens = len(indexer)
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
args.nlayers, args.dropout, args.tied)
model.collect_params().initialize(mx.init.Xavier(), ctx=context)
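
The deleted batchify/get_batch helpers carved one long token stream into fixed-width columns and stepped through them bptt rows at a time. The new pipeline instead asks the WikiText2 dataset for fixed-length samples and lets DataLoader group them into batches, using the IntervalSampler shown above to pick indices in strided passes and last_batch='discard' to drop any ragged tail so the batch size always matches the hidden state created by begin_state(batch_size=args.batch_size). A minimal plain-Python sketch of the index order an interval sampler plausibly produces (the stride semantics are an assumption for illustration, not the library implementation):

def interval_order(length, interval):
    # Visit [0, length) in strided passes: 0, interval, 2*interval, ..., then 1, 1+interval, ...
    for start in range(interval):
        for idx in range(start, length, interval):
            yield idx

# Assumed behaviour with 13 samples and a stride of 3:
print(list(interval_order(13, 3)))
# [0, 3, 6, 9, 12, 1, 4, 7, 10, 2, 5, 8, 11]
# The DataLoader above then takes batch_size consecutive indices from this
# order for each batch and discards a trailing partial batch.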
@@ -108,12 +117,6 @@ def batchify(data, batch_size):
# Training code
###############################################################################

def get_batch(source, i):
seq_len = min(args.bptt, source.shape[0] - 1 - i)
data = source[i:i+seq_len]
target = source[i+1:i+1+seq_len]
return data, target.reshape((-1,))

def detach(hidden):
if isinstance(hidden, (tuple, list)):
hidden = [i.detach() for i in hidden]
@@ -125,8 +128,9 @@ def eval(data_source):
total_L = 0.0
ntotal = 0
hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
for i in range(0, data_source.shape[0] - 1, args.bptt):
data, target = get_batch(data_source, i)
for i, (data, target) in enumerate(data_source):
data = data.as_in_context(context).T
target = target.as_in_context(context).T.reshape((-1, 1))
output, hidden = model(data, hidden)
L = loss(output, target)
total_L += mx.nd.sum(L).asscalar()
@@ -139,26 +143,27 @@ def train():
total_L = 0.0
start_time = time.time()
hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
for ibatch, i in enumerate(range(0, train_data.shape[0] - 1, args.bptt)):
data, target = get_batch(train_data, i)
for i, (data, target) in enumerate(train_data):
data = data.as_in_context(context).T
target = target.as_in_context(context).T.reshape((-1, 1))
hidden = detach(hidden)
with autograd.record():
output, hidden = model(data, hidden)
L = loss(output, target)
L.backward()

grads = [i.grad(context) for i in model.collect_params().values()]
grads = [p.grad(context) for p in model.collect_params().values()]
# Here gradient is for the whole batch.
# So we multiply max_norm by batch_size and bptt size to balance it.
gluon.utils.clip_global_norm(grads, args.clip * args.bptt * args.batch_size)

trainer.step(args.batch_size)
total_L += mx.nd.sum(L).asscalar()

if ibatch % args.log_interval == 0 and ibatch > 0:
if i % args.log_interval == 0 and i > 0:
cur_L = total_L / args.bptt / args.batch_size / args.log_interval
print('[Epoch %d Batch %d] loss %.2f, ppl %.2f'%(
epoch, ibatch, cur_L, math.exp(cur_L)))
epoch, i, cur_L, math.exp(cur_L)))
total_L = 0.0

val_L = eval(val_data)
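
The scaling comment above is worth unpacking: L holds one loss value per predicted token, and L.backward() accumulates gradients for the whole bptt × batch_size block at once, so the gradient norm grows with that block; multiplying the clip threshold by args.bptt * args.batch_size keeps the effective per-token threshold at args.clip. A small standalone sketch of gluon.utils.clip_global_norm with hypothetical numbers (clip 0.25, bptt 35, batch size 32 are illustrative, not taken from this commit):

import mxnet as mx
from mxnet import gluon

clip, bptt, batch_size = 0.25, 35, 32                        # hypothetical values
grads = [mx.nd.array([3.0, 4.0]), mx.nd.array([0.0, 12.0])]  # toy gradients, joint L2 norm = 13.0

# clip_global_norm rescales the arrays in place only when their joint norm
# exceeds max_norm; here the scaled threshold is 0.25 * 35 * 32 = 280, so the
# toy gradients are left untouched.
gluon.utils.clip_global_norm(grads, clip * bptt * batch_size)
print([g.asnumpy() for g in grads])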
