Language Modeling Datasets and Sampler (apache#9514)
* refactor dataset

* add interval sampler

* wikitext-2/-103

* update word language model

* address comments

* move interval sampler to contrib

* update

* add frequencies property
szha authored Jan 30, 2018
1 parent 0ff26df commit 7dcec50
Showing 10 changed files with 369 additions and 155 deletions.
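At a glance, the new API can be used like this (a sketch only, mirroring the calls made in train.py below; the data path and hyperparameter values are illustrative, not part of the commit):

from mxnet import gluon
from mxnet.gluon import contrib

# Load WikiText-2 as fixed-length samples of seq_len tokens each;
# the validation/test segments reuse the training vocabulary.
train_dataset = contrib.data.text.WikiText2('./data', 'train', seq_len=35)
vocab = train_dataset.vocabulary
val_dataset = contrib.data.text.WikiText2('./data', 'validation', vocab=vocab, seq_len=35)

# IntervalSampler with interval equal to the number of batches keeps each batch
# position on one contiguous stream of text across successive batches, so the
# recurrent hidden state can be carried over between batches.
batch_size = 32
nbatch = len(train_dataset) // batch_size
train_data = gluon.data.DataLoader(train_dataset, batch_size=batch_size,
                                   sampler=contrib.data.IntervalSampler(len(train_dataset), nbatch),
                                   last_batch='discard')
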
66 changes: 0 additions & 66 deletions example/gluon/word_language_model/data.py

This file was deleted.

68 changes: 39 additions & 29 deletions example/gluon/word_language_model/train.py
@@ -20,12 +20,11 @@
import math
import mxnet as mx
from mxnet import gluon, autograd
from mxnet.gluon import contrib
import model
import data

parser = argparse.ArgumentParser(description='MXNet Autograd PennTreeBank RNN/LSTM Language Model')
parser.add_argument('--data', type=str, default='./data/wikitext-2/wiki.',
help='location of the data corpus')
parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Language Model on Wikitext-2.')
parser.add_argument('--model', type=str, default='lstm',
help='type of recurrent net (rnn_tanh, rnn_relu, lstm, gru)')
parser.add_argument('--emsize', type=int, default=200,
@@ -72,26 +71,41 @@
else:
context = mx.cpu(0)

corpus = data.Corpus(args.data)

def batchify(data, batch_size):
"""Reshape data into (num_example, batch_size)"""
nbatch = data.shape[0] // batch_size
data = data[:nbatch * batch_size]
data = data.reshape((batch_size, nbatch)).T
return data

train_data = batchify(corpus.train, args.batch_size).as_in_context(context)
val_data = batchify(corpus.valid, args.batch_size).as_in_context(context)
test_data = batchify(corpus.test, args.batch_size).as_in_context(context)
train_dataset = contrib.data.text.WikiText2('./data', 'train', seq_len=args.bptt)
vocab = train_dataset.vocabulary
val_dataset, test_dataset = [contrib.data.text.WikiText2('./data', segment,
vocab=vocab,
seq_len=args.bptt)
for segment in ['validation', 'test']]

nbatch_train = len(train_dataset) // args.batch_size
train_data = gluon.data.DataLoader(train_dataset,
batch_size=args.batch_size,
sampler=contrib.data.IntervalSampler(len(train_dataset),
nbatch_train),
last_batch='discard')

nbatch_val = len(val_dataset) // args.batch_size
val_data = gluon.data.DataLoader(val_dataset,
batch_size=args.batch_size,
sampler=contrib.data.IntervalSampler(len(val_dataset),
nbatch_val),
last_batch='discard')

nbatch_test = len(test_dataset) // args.batch_size
test_data = gluon.data.DataLoader(test_dataset,
batch_size=args.batch_size,
sampler=contrib.data.IntervalSampler(len(test_dataset),
nbatch_test),
last_batch='discard')


###############################################################################
# Build the model
###############################################################################


ntokens = len(corpus.dictionary)
ntokens = len(vocab)
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
args.nlayers, args.dropout, args.tied)
model.collect_params().initialize(mx.init.Xavier(), ctx=context)
@@ -108,12 +122,6 @@ def batchify(data, batch_size):
# Training code
###############################################################################

def get_batch(source, i):
seq_len = min(args.bptt, source.shape[0] - 1 - i)
data = source[i:i+seq_len]
target = source[i+1:i+1+seq_len]
return data, target.reshape((-1,))

def detach(hidden):
if isinstance(hidden, (tuple, list)):
hidden = [i.detach() for i in hidden]
@@ -125,8 +133,9 @@ def eval(data_source):
total_L = 0.0
ntotal = 0
hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
for i in range(0, data_source.shape[0] - 1, args.bptt):
data, target = get_batch(data_source, i)
for i, (data, target) in enumerate(data_source):
data = data.as_in_context(context).T
target = target.as_in_context(context).T.reshape((-1, 1))
output, hidden = model(data, hidden)
L = loss(output, target)
total_L += mx.nd.sum(L).asscalar()
@@ -139,26 +148,27 @@ def train():
total_L = 0.0
start_time = time.time()
hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context)
for ibatch, i in enumerate(range(0, train_data.shape[0] - 1, args.bptt)):
data, target = get_batch(train_data, i)
for i, (data, target) in enumerate(train_data):
data = data.as_in_context(context).T
target = target.as_in_context(context).T.reshape((-1, 1))
hidden = detach(hidden)
with autograd.record():
output, hidden = model(data, hidden)
L = loss(output, target)
L.backward()

grads = [i.grad(context) for i in model.collect_params().values()]
grads = [p.grad(context) for p in model.collect_params().values()]
# Here gradient is for the whole batch.
# So we multiply max_norm by batch_size and bptt size to balance it.
gluon.utils.clip_global_norm(grads, args.clip * args.bptt * args.batch_size)

trainer.step(args.batch_size)
total_L += mx.nd.sum(L).asscalar()

if ibatch % args.log_interval == 0 and ibatch > 0:
if i % args.log_interval == 0 and i > 0:
cur_L = total_L / args.bptt / args.batch_size / args.log_interval
print('[Epoch %d Batch %d] loss %.2f, ppl %.2f'%(
epoch, ibatch, cur_L, math.exp(cur_L)))
epoch, i, cur_L, math.exp(cur_L)))
total_L = 0.0

val_L = eval(val_data)
2 changes: 2 additions & 0 deletions python/mxnet/gluon/contrib/__init__.py
@@ -21,3 +21,5 @@
from . import nn

from . import rnn

from . import data
22 changes: 5 additions & 17 deletions ...word_language_model/get_wikitext2_data.sh → python/mxnet/gluon/contrib/data/__init__.py
100755 → 100644
@@ -1,5 +1,3 @@
#!/usr/bin/env bash

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -17,20 +15,10 @@
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import
"""Contrib datasets."""

RNN_DIR=$(cd `dirname $0`; pwd)
DATA_DIR="${RNN_DIR}/data/"

if [[ ! -d "${DATA_DIR}" ]]; then
echo "${DATA_DIR} doesn't exist, will create one";
mkdir -p ${DATA_DIR}
fi

wget -P ${DATA_DIR} https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
cd ${DATA_DIR}
unzip wikitext-2-v1.zip
from . import text

# rename
mv ${DATA_DIR}/wikitext-2/wiki.test.tokens ${DATA_DIR}/wikitext-2/wiki.test.txt
mv ${DATA_DIR}/wikitext-2/wiki.valid.tokens ${DATA_DIR}/wikitext-2/wiki.valid.txt
mv ${DATA_DIR}/wikitext-2/wiki.train.tokens ${DATA_DIR}/wikitext-2/wiki.train.txt
from .sampler import *
22 changes: 22 additions & 0 deletions python/mxnet/gluon/contrib/data/_constants.py
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8

"""Read text files and load embeddings."""

EOS_TOKEN = '<eos>'
62 changes: 62 additions & 0 deletions python/mxnet/gluon/contrib/data/sampler.py
@@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=
"""Dataset sampler."""
__all__ = ['IntervalSampler']

from ...data import sampler

class IntervalSampler(sampler.Sampler):
    """Samples elements from [0, length) at fixed intervals.

    Parameters
    ----------
    length : int
        Length of the sequence.
    interval : int
        The number of items to skip between two samples.
    rollover : bool, default True
        Whether to start again from the first skipped item after reaching the end.
        If True, this sampler starts again from the first skipped item until all items
        are visited.
        Otherwise, iteration stops when the end is reached and skipped items are ignored.

    Examples
    --------
    >>> sampler = contrib.data.IntervalSampler(13, interval=3)
    >>> list(sampler)
    [0, 3, 6, 9, 12, 1, 4, 7, 10, 2, 5, 8, 11]
    >>> sampler = contrib.data.IntervalSampler(13, interval=3, rollover=False)
    >>> list(sampler)
    [0, 3, 6, 9, 12]
    """
    def __init__(self, length, interval, rollover=True):
        assert interval < length, \
            "Interval {} must be smaller than length {}".format(interval, length)
        self._length = length
        self._interval = interval
        self._rollover = rollover

    def __iter__(self):
        for i in range(self._interval if self._rollover else 1):
            for j in range(i, self._length, self._interval):
                yield j

    def __len__(self):
        return self._length
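
To see why train.py pairs this sampler with gluon.data.DataLoader, here is a small sketch (not part of the commit): with interval set to the number of batches per epoch, batch k holds indices [k, interval + k, 2 * interval + k, ...], so each batch position steps through consecutive dataset samples from one batch to the next.

from mxnet import gluon
from mxnet.gluon.contrib import data as contrib_data

length, batch_size = 12, 4               # 12 toy samples, 4 sequences per batch
interval = length // batch_size          # = 3 batches per epoch
sampler = contrib_data.IntervalSampler(length, interval)
print(list(sampler))                     # [0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11]

# Feed a trivial dataset through DataLoader with this sampler, as train.py does.
loader = gluon.data.DataLoader(list(range(length)), batch_size=batch_size,
                               sampler=sampler, last_batch='discard')
for batch in loader:
    print(batch.asnumpy())               # batches: (0, 3, 6, 9), (1, 4, 7, 10), (2, 5, 8, 11)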