diff --git a/example/gluon/word_language_model/train.py b/example/gluon/word_language_model/train.py
index ff20dfecea46..001e9f4930e7 100644
--- a/example/gluon/word_language_model/train.py
+++ b/example/gluon/word_language_model/train.py
@@ -20,6 +20,7 @@ import math
 import mxnet as mx
 from mxnet import gluon, autograd
+from mxnet.gluon import contrib
 import model
 import data
 
@@ -70,32 +71,32 @@ else:
     context = mx.cpu(0)
 
-train_dataset = gluon.data.text.WikiText2('./data', 'train', seq_len=args.bptt)
-indexer = train_dataset.indexer
-val_dataset, test_dataset = [gluon.data.text.WikiText2('./data', segment,
-                                                       indexer=indexer,
-                                                       seq_len=args.bptt)
+train_dataset = contrib.data.text.WikiText2('./data', 'train', seq_len=args.bptt)
+vocab = train_dataset.vocabulary
+val_dataset, test_dataset = [contrib.data.text.WikiText2('./data', segment,
+                                                          vocab=vocab,
+                                                          seq_len=args.bptt)
                              for segment in ['validation', 'test']]
 
 nbatch_train = len(train_dataset) / args.batch_size
 train_data = gluon.data.DataLoader(train_dataset,
                                    batch_size=args.batch_size,
-                                   sampler=gluon.data.IntervalSampler(len(train_dataset),
-                                                                      nbatch_train),
+                                   sampler=contrib.data.IntervalSampler(len(train_dataset),
+                                                                        nbatch_train),
                                    last_batch='discard')
 
 nbatch_val = len(val_dataset) / args.batch_size
 val_data = gluon.data.DataLoader(val_dataset,
                                  batch_size=args.batch_size,
-                                 sampler=gluon.data.IntervalSampler(len(val_dataset),
-                                                                    nbatch_val),
+                                 sampler=contrib.data.IntervalSampler(len(val_dataset),
+                                                                      nbatch_val),
                                  last_batch='discard')
 
 nbatch_test = len(test_dataset) / args.batch_size
 test_data = gluon.data.DataLoader(test_dataset,
                                   batch_size=args.batch_size,
-                                  sampler=gluon.data.IntervalSampler(len(test_dataset),
-                                                                     nbatch_test),
+                                  sampler=contrib.data.IntervalSampler(len(test_dataset),
+                                                                       nbatch_test),
                                   last_batch='discard')
 
@@ -104,7 +105,7 @@
 
 ###############################################################################
 
-ntokens = len(indexer)
+ntokens = len(vocab)
 model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers,
                        args.dropout, args.tied)
 model.collect_params().initialize(mx.init.Xavier(), ctx=context)
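For reference, after this change the example wires the contrib pieces together roughly as in the sketch below. It is a minimal, illustrative sketch only: the './data' root mirrors the script, while the bptt and batch_size values are placeholders rather than the argparse defaults.

# Illustrative sketch (not part of the patch): contrib dataset + IntervalSampler + DataLoader.
from mxnet import gluon
from mxnet.gluon import contrib

bptt, batch_size = 35, 32  # placeholder hyper-parameters

train_dataset = contrib.data.text.WikiText2('./data', 'train', seq_len=bptt)
vocab = train_dataset.vocabulary  # built from the training split when no vocab is passed in

# IntervalSampler emits indices 0, k, 2k, ..., then 1, k+1, ... (with k = number of batches),
# so row i of consecutive batches holds consecutive text chunks and an RNN can carry its
# hidden state across batches.
nbatch = len(train_dataset) // batch_size
train_data = gluon.data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   sampler=contrib.data.IntervalSampler(len(train_dataset), nbatch),
                                   last_batch='discard')

for data, target in train_data:
    pass  # data, target: (batch_size, bptt) batches of int32 token indices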
diff --git a/python/mxnet/gluon/contrib/data/_constants.py b/python/mxnet/gluon/contrib/data/_constants.py
new file mode 100644
index 000000000000..86974dce4502
--- /dev/null
+++ b/python/mxnet/gluon/contrib/data/_constants.py
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+
+"""Read text files and load embeddings."""
+
+EOS_TOKEN = '<eos>'
diff --git a/python/mxnet/gluon/contrib/data/text.py b/python/mxnet/gluon/contrib/data/text.py
index 2c76ec0613a7..2dff4b3119ad 100644
--- a/python/mxnet/gluon/contrib/data/text.py
+++ b/python/mxnet/gluon/contrib/data/text.py
@@ -26,73 +26,45 @@
 import shutil
 import numpy as np
 
+from . import _constants as C
 from ...data import dataset
 from ...utils import download, check_sha1
 from ....contrib import text
 from .... import nd
 
 
-class WikiText2(dataset._DownloadedDataset):
-    """WikiText-2 word-level dataset for language modeling, from Salesforce research.
+class _TextDataset(dataset._DownloadedDataset):  # pylint: disable=abstract-method
+    def __init__(self, repo_dir, root, vocabulary, transform):
+        self._vocab = vocabulary
+        super(_TextDataset, self).__init__(repo_dir, root, transform)
 
-    From
-    https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset
-
-    License: Creative Commons Attribution-ShareAlike
 
+    @property
+    def vocabulary(self):
+        return self._vocab
 
-    Each sample is a vector of length equal to the specified sequence length.
-    At the end of each sentence, an end-of-sentence token '<eos>' is added.
-
-    Parameters
-    ----------
-    root : str, default '~/.mxnet/datasets/cifar10'
-        Path to temp folder for storing data.
-    segment : str, default 'train'
-        Dataset segment. Options are 'train', 'validation', 'test'.
-    indexer : :class:`~mxnet.contrib.text.indexer.TokenIndexer`, default None
-        The indexer to use for indexing the text dataset. If None, a default indexer is created.
-    seq_len : int, default 35
-        The sequence length of each sample, regardless of the sentence boundary.
-    transform : function, default None
-        A user defined callback that transforms each sample. For example:
-    ::
-        transform=lambda data, label: (data.astype(np.float32)/255, label)
 
+class _WikiText(_TextDataset):
-    """
-    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-2'),
-                 segment='train', indexer=None, seq_len=35, transform=None):
-        self._archive_file = ('wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
-        self._data_file = {'train': ('wiki.train.tokens',
-                                     '863f29c46ef9d167fff4940ec821195882fe29d1'),
-                           'validation': ('wiki.valid.tokens',
-                                          '0418625c8b4da6e4b5c7a0b9e78d4ae8f7ee5422'),
-                           'test': ('wiki.test.tokens',
-                                    'c7b8ce0aa086fb34dab808c5c49224211eb2b172')}
-        self._segment = segment
-        self._seq_len = seq_len
-        self.indexer = indexer
-        super(WikiText2, self).__init__('wikitext-2', root, transform)
+    def _build_vocab(self, content):
+        if not self._vocab:
+            counter = text.utils.count_tokens_from_str(content)
+            self._vocab = text.vocab.Vocabulary(counter=counter,
+                                                reserved_tokens=[C.EOS_TOKEN])
 
     def _read_batch(self, filename):
         with io.open(filename, 'r', encoding='utf8') as fin:
             content = fin.read()
 
-        eos_token = '<eos>'
-        if not self.indexer:
-            counter = text.utils.count_tokens_from_str(content)
-            self.indexer = text.indexer.TokenIndexer(counter=counter,
-                                                     reserved_tokens=[eos_token])
+        self._build_vocab(content)
 
         raw_data = [line for line in [x.strip().split() for x in content.splitlines()] if line]
         for line in raw_data:
-            line.append(eos_token)
+            line.append(C.EOS_TOKEN)
 
         raw_data = [x for line in raw_data for x in line if x]
-        raw_data = self.indexer.to_indices(raw_data)
+        raw_data = self.vocabulary.to_indices(raw_data)
         data = raw_data[0:-1]
         label = raw_data[1:]
         return np.array(data, dtype=np.int32), np.array(label, dtype=np.int32)
 
-
     def _get_data(self):
         archive_file_name, archive_hash = self._archive_file
         data_file_name, data_hash = self._data_file[self._segment]
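The vocabulary step that the new _build_vocab and _read_batch share can be exercised on its own with mxnet.contrib.text, as in the sketch below; the corpus string is made up purely for illustration.

# Standalone sketch of the counting/indexing path used above.
from mxnet.contrib import text

content = 'the cat sat on the mat <eos>\nthe cat ran <eos>'
counter = text.utils.count_tokens_from_str(content)        # token -> frequency counter
vocab = text.vocab.Vocabulary(counter=counter,
                              reserved_tokens=['<eos>'])    # same role as C.EOS_TOKEN
indices = vocab.to_indices(content.split())                 # tokens -> integer ids
print(len(vocab), indices[:5])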
@@ -117,7 +89,50 @@ def _get_data(self):
         self._label = nd.array(label, dtype=label.dtype).reshape((-1, self._seq_len))
 
 
-class WikiText103(WikiText2):
+class WikiText2(_WikiText):
+    """WikiText-2 word-level dataset for language modeling, from Salesforce research.
+
+    From
+    https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset
+
+    License: Creative Commons Attribution-ShareAlike
+
+    Each sample is a vector of length equal to the specified sequence length.
+    At the end of each sentence, an end-of-sentence token '<eos>' is added.
+
+    Parameters
+    ----------
+    root : str, default '~/.mxnet/datasets/wikitext-2'
+        Path to temp folder for storing data.
+    segment : str, default 'train'
+        Dataset segment. Options are 'train', 'validation', 'test'.
+    vocab : :class:`~mxnet.contrib.text.vocab.Vocabulary`, default None
+        The vocabulary to use for indexing the text dataset.
+        If None, a default vocabulary is created.
+    seq_len : int, default 35
+        The sequence length of each sample, regardless of the sentence boundary.
+    transform : function, default None
+        A user defined callback that transforms each sample. For example:
+    ::
+
+        transform=lambda data, label: (data.astype(np.float32)/255, label)
+
+    """
+    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-2'),
+                 segment='train', vocab=None, seq_len=35, transform=None):
+        self._archive_file = ('wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
+        self._data_file = {'train': ('wiki.train.tokens',
+                                     '863f29c46ef9d167fff4940ec821195882fe29d1'),
+                           'validation': ('wiki.valid.tokens',
+                                          '0418625c8b4da6e4b5c7a0b9e78d4ae8f7ee5422'),
+                           'test': ('wiki.test.tokens',
+                                    'c7b8ce0aa086fb34dab808c5c49224211eb2b172')}
+        self._segment = segment
+        self._seq_len = seq_len
+        super(WikiText2, self).__init__('wikitext-2', root, vocab, transform)
+
+
+class WikiText103(_WikiText):
     """WikiText-103 word-level dataset for language modeling, from Salesforce research.
 
     From
@@ -134,8 +149,9 @@ class WikiText103(WikiText2):
         Path to temp folder for storing data.
     segment : str, default 'train'
         Dataset segment. Options are 'train', 'validation', 'test'.
-    indexer : :class:`~mxnet.contrib.text.indexer.TokenIndexer`, default None
-        The indexer to use for indexing the text dataset. If None, a default indexer is created.
+    vocab : :class:`~mxnet.contrib.text.vocab.Vocabulary`, default None
+        The vocabulary to use for indexing the text dataset.
+        If None, a default vocabulary is created.
     seq_len : int, default 35
         The sequence length of each sample, regardless of the sentence boundary.
     transform : function, default None
@@ -146,7 +162,7 @@ class WikiText103(WikiText2):
     """
     def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-103'),
-                 segment='train', indexer=None, seq_len=35, transform=None):
+                 segment='train', vocab=None, seq_len=35, transform=None):
         self._archive_file = ('wikitext-103-v1.zip', '0aec09a7537b58d4bb65362fee27650eeaba625a')
         self._data_file = {'train': ('wiki.train.tokens',
                                      'b7497e2dfe77e72cfef5e3dbc61b7b53712ac211'),
@@ -156,5 +172,4 @@ def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-103'),
                                     '8a5befc548865cec54ed4273cf87dbbad60d1e47')}
         self._segment = segment
         self._seq_len = seq_len
-        self.indexer = indexer
-        super(WikiText2, self).__init__('wikitext-103', root, transform)  # pylint: disable=bad-super-call
+        super(WikiText103, self).__init__('wikitext-103', root, vocab, transform)
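Taken together, the new vocab parameter lets several segments share a single token-to-index mapping, along the lines of the sketch below. The sketch assumes the archives can be downloaded to the default root, and that indexing a dataset yields a (data, label) pair as _get_data above suggests.

# Sketch: build the vocabulary on the training split once and reuse it, so validation/test
# indices stay consistent with the model's embedding and output layers.
from mxnet.gluon import contrib

train = contrib.data.text.WikiText2(segment='train', seq_len=35)
vocab = train.vocabulary
val = contrib.data.text.WikiText2(segment='validation', vocab=vocab, seq_len=35)
test = contrib.data.text.WikiText2(segment='test', vocab=vocab, seq_len=35)

ntokens = len(vocab)      # e.g. the embedding/output size, as in the example script
data, label = train[0]    # each of length seq_len; label is data shifted by one token

# WikiText103 exposes the same constructor arguments.
big_train = contrib.data.text.WikiText103(segment='train', seq_len=35)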