
Commit

update
szha committed Jan 25, 2018
1 parent 12891d4 commit 2d42d83
Showing 3 changed files with 101 additions and 63 deletions.
25 changes: 13 additions & 12 deletions example/gluon/word_language_model/train.py
@@ -20,6 +20,7 @@
import math
import mxnet as mx
from mxnet import gluon, autograd
from mxnet.gluon import contrib
import model
import data

@@ -70,32 +71,32 @@
else:
context = mx.cpu(0)

train_dataset = gluon.data.text.WikiText2('./data', 'train', seq_len=args.bptt)
indexer = train_dataset.indexer
val_dataset, test_dataset = [gluon.data.text.WikiText2('./data', segment,
indexer=indexer,
seq_len=args.bptt)
train_dataset = contrib.data.text.WikiText2('./data', 'train', seq_len=args.bptt)
vocab = train_dataset.vocabulary
val_dataset, test_dataset = [contrib.data.text.WikiText2('./data', segment,
vocab=vocab,
seq_len=args.bptt)
for segment in ['validation', 'test']]

nbatch_train = len(train_dataset) / args.batch_size
train_data = gluon.data.DataLoader(train_dataset,
batch_size=args.batch_size,
sampler=gluon.data.IntervalSampler(len(train_dataset),
nbatch_train),
sampler=contrib.data.IntervalSampler(len(train_dataset),
nbatch_train),
last_batch='discard')

nbatch_val = len(val_dataset) / args.batch_size
val_data = gluon.data.DataLoader(val_dataset,
batch_size=args.batch_size,
sampler=gluon.data.IntervalSampler(len(val_dataset),
nbatch_val),
sampler=contrib.data.IntervalSampler(len(val_dataset),
nbatch_val),
last_batch='discard')

nbatch_test = len(test_dataset) / args.batch_size
test_data = gluon.data.DataLoader(test_dataset,
batch_size=args.batch_size,
sampler=gluon.data.IntervalSampler(len(test_dataset),
nbatch_test),
sampler=contrib.data.IntervalSampler(len(test_dataset),
nbatch_test),
last_batch='discard')


@@ -104,7 +105,7 @@
###############################################################################


ntokens = len(indexer)
ntokens = len(vocab)
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid,
args.nlayers, args.dropout, args.tied)
model.collect_params().initialize(mx.init.Xavier(), ctx=context)
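The example now builds its vocabulary from the training split via the new `contrib.data.text` datasets and reuses it for the other segments. A minimal sketch of the resulting pipeline, assuming an MXNet build that ships `mxnet.gluon.contrib`; the batch size and BPTT length below are illustrative, not taken from the commit:

```python
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import contrib

batch_size, bptt = 32, 35  # illustrative hyperparameters

# Build the vocabulary on the training split, then share it with validation.
train_dataset = contrib.data.text.WikiText2('./data', 'train', seq_len=bptt)
vocab = train_dataset.vocabulary
val_dataset = contrib.data.text.WikiText2('./data', 'validation',
                                          vocab=vocab, seq_len=bptt)

# As in the updated train.py, IntervalSampler paired with last_batch='discard'
# keeps position i of successive batches on the same token stream.
nbatch_train = len(train_dataset) // batch_size
train_data = gluon.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=contrib.data.IntervalSampler(len(train_dataset), nbatch_train),
    last_batch='discard')

for data, label in train_data:
    # data and label are (batch_size, bptt) batches of int32 token indices
    break
```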
22 changes: 22 additions & 0 deletions python/mxnet/gluon/contrib/data/_constants.py
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8

"""Read text files and load embeddings."""

EOS_TOKEN = '<eos>'
117 changes: 66 additions & 51 deletions python/mxnet/gluon/contrib/data/text.py
@@ -26,73 +26,45 @@
import shutil
import numpy as np

from . import _constants as C
from ...data import dataset
from ...utils import download, check_sha1
from ....contrib import text
from .... import nd


class WikiText2(dataset._DownloadedDataset):
"""WikiText-2 word-level dataset for language modeling, from Salesforce research.
class _TextDataset(dataset._DownloadedDataset): # pylint: disable=abstract-method
def __init__(self, repo_dir, root, vocabulary, transform):
self._vocab = vocabulary
super(_TextDataset, self).__init__(repo_dir, root, transform)

From
https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset
License: Creative Commons Attribution-ShareAlike
@property
def vocabulary(self):
return self._vocab

Each sample is a vector of length equal to the specified sequence length.
At the end of each sentence, an end-of-sentence token '<eos>' is added.
Parameters
----------
root : str, default '~/.mxnet/datasets/cifar10'
Path to temp folder for storing data.
segment : str, default 'train'
Dataset segment. Options are 'train', 'validation', 'test'.
indexer : :class:`~mxnet.contrib.text.indexer.TokenIndexer`, default None
The indexer to use for indexing the text dataset. If None, a default indexer is created.
seq_len : int, default 35
The sequence length of each sample, regardless of the sentence boundary.
transform : function, default None
A user defined callback that transforms each sample. For example:
::

transform=lambda data, label: (data.astype(np.float32)/255, label)
class _WikiText(_TextDataset):

"""
def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-2'),
segment='train', indexer=None, seq_len=35, transform=None):
self._archive_file = ('wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
self._data_file = {'train': ('wiki.train.tokens',
'863f29c46ef9d167fff4940ec821195882fe29d1'),
'validation': ('wiki.valid.tokens',
'0418625c8b4da6e4b5c7a0b9e78d4ae8f7ee5422'),
'test': ('wiki.test.tokens',
'c7b8ce0aa086fb34dab808c5c49224211eb2b172')}
self._segment = segment
self._seq_len = seq_len
self.indexer = indexer
super(WikiText2, self).__init__('wikitext-2', root, transform)
def _build_vocab(self, content):
if not self._vocab:
counter = text.utils.count_tokens_from_str(content)
self._vocab = text.vocab.Vocabulary(counter=counter,
reserved_tokens=[C.EOS_TOKEN])

def _read_batch(self, filename):
with io.open(filename, 'r', encoding='utf8') as fin:
content = fin.read()
eos_token = '<eos>'
if not self.indexer:
counter = text.utils.count_tokens_from_str(content)
self.indexer = text.indexer.TokenIndexer(counter=counter,
reserved_tokens=[eos_token])
self._build_vocab(content)
raw_data = [line for line in [x.strip().split() for x in content.splitlines()]
if line]
for line in raw_data:
line.append(eos_token)
line.append(C.EOS_TOKEN)
raw_data = [x for x in line for line in raw_data if x]
raw_data = self.indexer.to_indices(raw_data)
raw_data = self.vocabulary.to_indices(raw_data)
data = raw_data[0:-1]
label = raw_data[1:]
return np.array(data, dtype=np.int32), np.array(label, dtype=np.int32)


def _get_data(self):
archive_file_name, archive_hash = self._archive_file
data_file_name, data_hash = self._data_file[self._segment]
@@ -117,7 +89,50 @@ def _get_data(self):
self._label = nd.array(label, dtype=label.dtype).reshape((-1, self._seq_len))
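For reference, `_read_batch` returns the token-index stream and the same stream shifted by one position, and `_get_data` reshapes both into rows of length `seq_len`. A small NumPy sketch of that indexing; the helper name and the way the trailing remainder is dropped are illustrative assumptions, not the library's code:

```python
import numpy as np

def make_samples(indices, seq_len=35):
    """Illustrative only: turn a flat index stream into (data, label) samples."""
    indices = np.asarray(indices, dtype=np.int32)
    data, label = indices[:-1], indices[1:]   # label is data shifted by one token
    n = (len(data) // seq_len) * seq_len      # assume the trailing remainder is dropped
    return data[:n].reshape((-1, seq_len)), label[:n].reshape((-1, seq_len))

data, label = make_samples(range(12), seq_len=5)
# data  -> [[0 1 2 3 4], [5 6 7 8 9]]
# label -> [[1 2 3 4 5], [6 7 8 9 10]]
```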


class WikiText103(WikiText2):
class WikiText2(_WikiText):
"""WikiText-2 word-level dataset for language modeling, from Salesforce research.
From
https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset
License: Creative Commons Attribution-ShareAlike
Each sample is a vector of length equal to the specified sequence length.
At the end of each sentence, an end-of-sentence token '<eos>' is added.
Parameters
----------
root : str, default '~/.mxnet/datasets/cifar10'
Path to temp folder for storing data.
segment : str, default 'train'
Dataset segment. Options are 'train', 'validation', 'test'.
vocab : :class:`~mxnet.contrib.text.vocab.Vocabulary`, default None
The vocabulary to use for indexing the text dataset.
If None, a default vocabulary is created.
seq_len : int, default 35
The sequence length of each sample, regardless of the sentence boundary.
transform : function, default None
A user defined callback that transforms each sample. For example:
::
transform=lambda data, label: (data.astype(np.float32)/255, label)
"""
def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-2'),
segment='train', vocab=None, seq_len=35, transform=None):
self._archive_file = ('wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
self._data_file = {'train': ('wiki.train.tokens',
'863f29c46ef9d167fff4940ec821195882fe29d1'),
'validation': ('wiki.valid.tokens',
'0418625c8b4da6e4b5c7a0b9e78d4ae8f7ee5422'),
'test': ('wiki.test.tokens',
'c7b8ce0aa086fb34dab808c5c49224211eb2b172')}
self._segment = segment
self._seq_len = seq_len
super(WikiText2, self).__init__('wikitext-2', root, vocab, transform)


class WikiText103(_WikiText):
"""WikiText-103 word-level dataset for language modeling, from Salesforce research.
From
Expand All @@ -134,8 +149,9 @@ class WikiText103(WikiText2):
Path to temp folder for storing data.
segment : str, default 'train'
Dataset segment. Options are 'train', 'validation', 'test'.
indexer : :class:`~mxnet.contrib.text.indexer.TokenIndexer`, default None
The indexer to use for indexing the text dataset. If None, a default indexer is created.
vocab : :class:`~mxnet.contrib.text.vocab.Vocabulary`, default None
The vocabulary to use for indexing the text dataset.
If None, a default vocabulary is created.
seq_len : int, default 35
The sequence length of each sample, regardless of the sentence boundary.
transform : function, default None
@@ -146,7 +162,7 @@ class WikiText103(WikiText2):
"""
def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-103'),
segment='train', indexer=None, seq_len=35, transform=None):
segment='train', vocab=None, seq_len=35, transform=None):
self._archive_file = ('wikitext-103-v1.zip', '0aec09a7537b58d4bb65362fee27650eeaba625a')
self._data_file = {'train': ('wiki.train.tokens',
'b7497e2dfe77e72cfef5e3dbc61b7b53712ac211'),
@@ -156,5 +172,4 @@ def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-103'),
'8a5befc548865cec54ed4273cf87dbbad60d1e47')}
self._segment = segment
self._seq_len = seq_len
self.indexer = indexer
super(WikiText2, self).__init__('wikitext-103', root, transform) # pylint: disable=bad-super-call
super(WikiText103, self).__init__('wikitext-103', root, vocab, transform)
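Both dataset classes now expose the vocabulary they were built with, which the example script uses for the model's output size (`ntokens = len(vocab)`). A short usage sketch; the `'./data'` root and `seq_len` are illustrative, and the first use will download the archive:

```python
from mxnet.gluon import contrib

wt103 = contrib.data.text.WikiText103('./data', 'train', seq_len=35)
vocab = wt103.vocabulary

ntokens = len(vocab)                  # vocabulary size, e.g. for the output layer
data, label = wt103[0]                # one sample: two length-35 vectors of indices
eos_id = vocab.to_indices(['<eos>'])  # same conversion _read_batch applies to the corpus
```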
