From 7dcec50f1300c98ba11940a1abedaec7003a5ac7 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Mon, 29 Jan 2018 18:00:46 -0800 Subject: [PATCH] Language Modeling Datasets and Sampler (#9514) * refactor dataset * add interval sampler * wikitext-2/-103 * update word language model * address comments * move interval sampler to contrib * update * add frequencies property --- example/gluon/word_language_model/data.py | 66 ------- example/gluon/word_language_model/train.py | 68 ++++--- python/mxnet/gluon/contrib/__init__.py | 2 + .../mxnet/gluon/contrib/data/__init__.py | 22 +-- python/mxnet/gluon/contrib/data/_constants.py | 22 +++ python/mxnet/gluon/contrib/data/sampler.py | 62 +++++++ python/mxnet/gluon/contrib/data/text.py | 170 ++++++++++++++++++ python/mxnet/gluon/data/dataset.py | 36 ++++ python/mxnet/gluon/data/vision/datasets.py | 53 ++---- tests/python/unittest/test_gluon_contrib.py | 23 +++ 10 files changed, 369 insertions(+), 155 deletions(-) delete mode 100644 example/gluon/word_language_model/data.py rename example/gluon/word_language_model/get_wikitext2_data.sh => python/mxnet/gluon/contrib/data/__init__.py (57%) mode change 100755 => 100644 create mode 100644 python/mxnet/gluon/contrib/data/_constants.py create mode 100644 python/mxnet/gluon/contrib/data/sampler.py create mode 100644 python/mxnet/gluon/contrib/data/text.py diff --git a/example/gluon/word_language_model/data.py b/example/gluon/word_language_model/data.py deleted file mode 100644 index 913963ec20cb..000000000000 --- a/example/gluon/word_language_model/data.py +++ /dev/null @@ -1,66 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-
-import os
-import numpy as np
-import mxnet as mx
-
-class Dictionary(object):
-    def __init__(self):
-        self.word2idx = {}
-        self.idx2word = []
-
-    def add_word(self, word):
-        if word not in self.word2idx:
-            self.idx2word.append(word)
-            self.word2idx[word] = len(self.idx2word) - 1
-        return self.word2idx[word]
-
-    def __len__(self):
-        return len(self.idx2word)
-
-
-class Corpus(object):
-    def __init__(self, path):
-        self.dictionary = Dictionary()
-        self.train = self.tokenize(path + 'train.txt')
-        self.valid = self.tokenize(path + 'valid.txt')
-        self.test = self.tokenize(path + 'test.txt')
-
-    def tokenize(self, path):
-        """Tokenizes a text file."""
-        assert os.path.exists(path)
-        # Add words to the dictionary
-        with open(path, 'r') as f:
-            tokens = 0
-            for line in f:
-                words = line.split() + ['<eos>']
-                tokens += len(words)
-                for word in words:
-                    self.dictionary.add_word(word)
-
-        # Tokenize file content
-        with open(path, 'r') as f:
-            ids = np.zeros((tokens,), dtype='int32')
-            token = 0
-            for line in f:
-                words = line.split() + ['<eos>']
-                for word in words:
-                    ids[token] = self.dictionary.word2idx[word]
-                    token += 1
-
-        return mx.nd.array(ids, dtype='int32')
diff --git a/example/gluon/word_language_model/train.py b/example/gluon/word_language_model/train.py
index eb584b822ad0..001e9f4930e7 100644
--- a/example/gluon/word_language_model/train.py
+++ b/example/gluon/word_language_model/train.py
@@ -20,12 +20,11 @@
 import math
 import mxnet as mx
 from mxnet import gluon, autograd
+from mxnet.gluon import contrib
 import model
 import data
 
-parser = argparse.ArgumentParser(description='MXNet Autograd PennTreeBank RNN/LSTM Language Model')
-parser.add_argument('--data', type=str, default='./data/wikitext-2/wiki.',
-                    help='location of the data corpus')
+parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Language Model on Wikitext-2.')
 parser.add_argument('--model', type=str, default='lstm',
                     help='type of recurrent net (rnn_tanh, rnn_relu, lstm, gru)')
 parser.add_argument('--emsize', type=int, default=200,
@@ -72,18 +71,33 @@
 else:
     context = mx.cpu(0)
 
-corpus = data.Corpus(args.data)
-
-def batchify(data, batch_size):
-    """Reshape data into (num_example, batch_size)"""
-    nbatch = data.shape[0] // batch_size
-    data = data[:nbatch * batch_size]
-    data = data.reshape((batch_size, nbatch)).T
-    return data
-
-train_data = batchify(corpus.train, args.batch_size).as_in_context(context)
-val_data = batchify(corpus.valid, args.batch_size).as_in_context(context)
-test_data = batchify(corpus.test, args.batch_size).as_in_context(context)
+train_dataset = contrib.data.text.WikiText2('./data', 'train', seq_len=args.bptt)
+vocab = train_dataset.vocabulary
+val_dataset, test_dataset = [contrib.data.text.WikiText2('./data', segment,
+                                                         vocab=vocab,
+                                                         seq_len=args.bptt)
+                             for segment in ['validation', 'test']]
+
+nbatch_train = len(train_dataset) // args.batch_size
+train_data = gluon.data.DataLoader(train_dataset,
+                                   batch_size=args.batch_size,
+                                   sampler=contrib.data.IntervalSampler(len(train_dataset),
+                                                                        nbatch_train),
+                                   last_batch='discard')
+
+nbatch_val = len(val_dataset) // args.batch_size
+val_data = gluon.data.DataLoader(val_dataset,
+                                 batch_size=args.batch_size,
+                                 sampler=contrib.data.IntervalSampler(len(val_dataset),
+                                                                      nbatch_val),
+                                 last_batch='discard')
+
+nbatch_test = len(test_dataset) // args.batch_size
+test_data = gluon.data.DataLoader(test_dataset,
+                                  batch_size=args.batch_size,
+                                  sampler=contrib.data.IntervalSampler(len(test_dataset),
+                                                                       nbatch_test),
+                                  last_batch='discard')
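[Editor's note, not part of the patch: the sketch below, with arbitrary batch_size and bptt values, illustrates why sampling at a stride of len(dataset) // batch_size together with last_batch='discard' reproduces the column layout the removed batchify() helper used to build, so hidden state can still be carried from one batch to the next.]

# Illustrative sketch only; assumes the WikiText2 and IntervalSampler
# definitions introduced in this patch.
from mxnet import gluon
from mxnet.gluon import contrib

batch_size, bptt = 4, 35
dataset = contrib.data.text.WikiText2('./data', 'train', seq_len=bptt)

# Item i of the dataset is the i-th consecutive bptt-long slice of the
# flattened corpus. With a stride of nbatch, row k of batch j is item
# k * nbatch + j, and row k of batch j + 1 is item k * nbatch + j + 1,
# i.e. the next consecutive slice -- the same layout batchify() produced.
nbatch = len(dataset) // batch_size
loader = gluon.data.DataLoader(dataset,
                               batch_size=batch_size,
                               sampler=contrib.data.IntervalSampler(len(dataset), nbatch),
                               last_batch='discard')

for data, target in loader:
    print(data.shape, target.shape)   # (batch_size, bptt) each; train.py transposes them
    break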
############################################################################### @@ -91,7 +105,7 @@ def batchify(data, batch_size): ############################################################################### -ntokens = len(corpus.dictionary) +ntokens = len(vocab) model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers, args.dropout, args.tied) model.collect_params().initialize(mx.init.Xavier(), ctx=context) @@ -108,12 +122,6 @@ def batchify(data, batch_size): # Training code ############################################################################### -def get_batch(source, i): - seq_len = min(args.bptt, source.shape[0] - 1 - i) - data = source[i:i+seq_len] - target = source[i+1:i+1+seq_len] - return data, target.reshape((-1,)) - def detach(hidden): if isinstance(hidden, (tuple, list)): hidden = [i.detach() for i in hidden] @@ -125,8 +133,9 @@ def eval(data_source): total_L = 0.0 ntotal = 0 hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context) - for i in range(0, data_source.shape[0] - 1, args.bptt): - data, target = get_batch(data_source, i) + for i, (data, target) in enumerate(data_source): + data = data.as_in_context(context).T + target = target.as_in_context(context).T.reshape((-1, 1)) output, hidden = model(data, hidden) L = loss(output, target) total_L += mx.nd.sum(L).asscalar() @@ -139,15 +148,16 @@ def train(): total_L = 0.0 start_time = time.time() hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context) - for ibatch, i in enumerate(range(0, train_data.shape[0] - 1, args.bptt)): - data, target = get_batch(train_data, i) + for i, (data, target) in enumerate(train_data): + data = data.as_in_context(context).T + target = target.as_in_context(context).T.reshape((-1, 1)) hidden = detach(hidden) with autograd.record(): output, hidden = model(data, hidden) L = loss(output, target) L.backward() - grads = [i.grad(context) for i in model.collect_params().values()] + grads = [p.grad(context) for p in model.collect_params().values()] # Here gradient is for the whole batch. # So we multiply max_norm by batch_size and bptt size to balance it. gluon.utils.clip_global_norm(grads, args.clip * args.bptt * args.batch_size) @@ -155,10 +165,10 @@ def train(): trainer.step(args.batch_size) total_L += mx.nd.sum(L).asscalar() - if ibatch % args.log_interval == 0 and ibatch > 0: + if i % args.log_interval == 0 and i > 0: cur_L = total_L / args.bptt / args.batch_size / args.log_interval print('[Epoch %d Batch %d] loss %.2f, ppl %.2f'%( - epoch, ibatch, cur_L, math.exp(cur_L))) + epoch, i, cur_L, math.exp(cur_L))) total_L = 0.0 val_L = eval(val_data) diff --git a/python/mxnet/gluon/contrib/__init__.py b/python/mxnet/gluon/contrib/__init__.py index e06438b4f80c..f708fb900227 100644 --- a/python/mxnet/gluon/contrib/__init__.py +++ b/python/mxnet/gluon/contrib/__init__.py @@ -21,3 +21,5 @@ from . import nn from . import rnn + +from . import data diff --git a/example/gluon/word_language_model/get_wikitext2_data.sh b/python/mxnet/gluon/contrib/data/__init__.py old mode 100755 new mode 100644 similarity index 57% rename from example/gluon/word_language_model/get_wikitext2_data.sh rename to python/mxnet/gluon/contrib/data/__init__.py index e9b8461c4005..7cb25eb7498e --- a/example/gluon/word_language_model/get_wikitext2_data.sh +++ b/python/mxnet/gluon/contrib/data/__init__.py @@ -1,5 +1,3 @@ -#!/usr/bin/env bash - # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. 
See the NOTICE file
 # distributed with this work for additional information
@@ -17,20 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 
+# coding: utf-8
+# pylint: disable=wildcard-import
+"""Contrib datasets."""
-RNN_DIR=$(cd `dirname $0`; pwd)
-DATA_DIR="${RNN_DIR}/data/"
-
-if [[ ! -d "${DATA_DIR}" ]]; then
-  echo "${DATA_DIR} doesn't exist, will create one";
-  mkdir -p ${DATA_DIR}
-fi
-
-wget -P ${DATA_DIR} https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
-cd ${DATA_DIR}
-unzip wikitext-2-v1.zip
+from . import text
 
-# rename
-mv ${DATA_DIR}/wikitext-2/wiki.test.tokens ${DATA_DIR}/wikitext-2/wiki.test.txt
-mv ${DATA_DIR}/wikitext-2/wiki.valid.tokens ${DATA_DIR}/wikitext-2/wiki.valid.txt
-mv ${DATA_DIR}/wikitext-2/wiki.train.tokens ${DATA_DIR}/wikitext-2/wiki.train.txt
+from .sampler import *
diff --git a/python/mxnet/gluon/contrib/data/_constants.py b/python/mxnet/gluon/contrib/data/_constants.py
new file mode 100644
index 000000000000..86974dce4502
--- /dev/null
+++ b/python/mxnet/gluon/contrib/data/_constants.py
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+
+"""Read text files and load embeddings."""
+
+EOS_TOKEN = '<eos>'
diff --git a/python/mxnet/gluon/contrib/data/sampler.py b/python/mxnet/gluon/contrib/data/sampler.py
new file mode 100644
index 000000000000..91136bd2368e
--- /dev/null
+++ b/python/mxnet/gluon/contrib/data/sampler.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=
+"""Dataset sampler."""
+__all__ = ['IntervalSampler']
+
+from ...data import sampler
+
+class IntervalSampler(sampler.Sampler):
+    """Samples elements from [0, length) at fixed intervals.
+
+    Parameters
+    ----------
+    length : int
+        Length of the sequence.
+    interval : int
+        The number of items to skip between two samples.
+    rollover : bool, default True
+        Whether to start again from the first skipped item after reaching the end.
+ If true, this sampler would start again from the first skipped item until all items + are visited. + Otherwise, iteration stops when end is reached and skipped items are ignored. + + Examples + -------- + >>> sampler = contrib.data.IntervalSampler(13, interval=3) + >>> list(sampler) + [0, 3, 6, 9, 12, 1, 4, 7, 10, 2, 5, 8, 11] + >>> sampler = contrib.data.IntervalSampler(13, interval=3, rollover=False) + >>> list(sampler) + [0, 3, 6, 9, 12] + """ + def __init__(self, length, interval, rollover=True): + assert interval < length, \ + "Interval {} must be smaller than length {}".format(interval, length) + self._length = length + self._interval = interval + self._rollover = rollover + + def __iter__(self): + for i in range(self._interval if self._rollover else 1): + for j in range(i, self._length, self._interval): + yield j + + def __len__(self): + return self._length diff --git a/python/mxnet/gluon/contrib/data/text.py b/python/mxnet/gluon/contrib/data/text.py new file mode 100644 index 000000000000..82f780942a03 --- /dev/null +++ b/python/mxnet/gluon/contrib/data/text.py @@ -0,0 +1,170 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable= +"""Text datasets.""" +__all__ = ['WikiText2', 'WikiText103'] + +import io +import os +import zipfile +import shutil +import numpy as np + +from . import _constants as C +from ...data import dataset +from ...utils import download, check_sha1 +from ....contrib import text +from .... 
import nd
+
+
+class _LanguageModelDataset(dataset._DownloadedDataset):  # pylint: disable=abstract-method
+    def __init__(self, repo_dir, root, vocabulary):
+        self._vocab = vocabulary
+        self._counter = None
+        super(_LanguageModelDataset, self).__init__(repo_dir, root, None)
+
+    @property
+    def vocabulary(self):
+        return self._vocab
+
+    @property
+    def frequencies(self):
+        return self._counter
+
+    def _build_vocab(self, content):
+        if not self._counter:
+            self._counter = text.utils.count_tokens_from_str(content)
+        if not self._vocab:
+            self._vocab = text.vocab.Vocabulary(counter=self.frequencies,
+                                                reserved_tokens=[C.EOS_TOKEN])
+
+
+class _WikiText(_LanguageModelDataset):
+
+    def _read_batch(self, filename):
+        with io.open(filename, 'r', encoding='utf8') as fin:
+            content = fin.read()
+        self._build_vocab(content)
+
+        raw_data = [line for line in [x.strip().split() for x in content.splitlines()]
+                    if line]
+        for line in raw_data:
+            line.append(C.EOS_TOKEN)
+        raw_data = self.vocabulary.to_indices([x for line in raw_data for x in line if x])
+        data = raw_data[0:-1]
+        label = raw_data[1:]
+        return np.array(data, dtype=np.int32), np.array(label, dtype=np.int32)
+
+    def _get_data(self):
+        archive_file_name, archive_hash = self._archive_file
+        data_file_name, data_hash = self._data_file[self._segment]
+        path = os.path.join(self._root, data_file_name)
+        if not os.path.exists(path) or not check_sha1(path, data_hash):
+            downloaded_file_path = download(self._get_url(archive_file_name),
+                                            path=self._root,
+                                            sha1_hash=archive_hash)
+
+            with zipfile.ZipFile(downloaded_file_path, 'r') as zf:
+                for member in zf.namelist():
+                    filename = os.path.basename(member)
+                    if filename:
+                        dest = os.path.join(self._root, filename)
+                        with zf.open(member) as source, \
+                             open(dest, "wb") as target:
+                            shutil.copyfileobj(source, target)
+
+        data, label = self._read_batch(os.path.join(self._root, data_file_name))
+
+        self._data = nd.array(data, dtype=data.dtype).reshape((-1, self._seq_len))
+        self._label = nd.array(label, dtype=label.dtype).reshape((-1, self._seq_len))
+
+
+class WikiText2(_WikiText):
+    """WikiText-2 word-level dataset for language modeling, from Salesforce research.
+
+    From
+    https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset
+
+    License: Creative Commons Attribution-ShareAlike
+
+    Each sample is a vector of length equal to the specified sequence length.
+    At the end of each sentence, an end-of-sentence token '<eos>' is added.
+
+    Parameters
+    ----------
+    root : str, default '~/.mxnet/datasets/wikitext-2'
+        Path to temp folder for storing data.
+    segment : str, default 'train'
+        Dataset segment. Options are 'train', 'validation', 'test'.
+    vocab : :class:`~mxnet.contrib.text.vocab.Vocabulary`, default None
+        The vocabulary to use for indexing the text dataset.
+        If None, a default vocabulary is created.
+    seq_len : int, default 35
+        The sequence length of each sample, regardless of the sentence boundary.
+
+    """
+    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-2'),
+                 segment='train', vocab=None, seq_len=35):
+        self._archive_file = ('wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
+        self._data_file = {'train': ('wiki.train.tokens',
+                                     '863f29c46ef9d167fff4940ec821195882fe29d1'),
+                           'validation': ('wiki.valid.tokens',
+                                          '0418625c8b4da6e4b5c7a0b9e78d4ae8f7ee5422'),
+                           'test': ('wiki.test.tokens',
+                                    'c7b8ce0aa086fb34dab808c5c49224211eb2b172')}
+        self._segment = segment
+        self._seq_len = seq_len
+        super(WikiText2, self).__init__('wikitext-2', root, vocab)
+
+
+class WikiText103(_WikiText):
+    """WikiText-103 word-level dataset for language modeling, from Salesforce research.
+
+    From
+    https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset
+
+    License: Creative Commons Attribution-ShareAlike
+
+    Each sample is a vector of length equal to the specified sequence length.
+    At the end of each sentence, an end-of-sentence token '<eos>' is added.
+
+    Parameters
+    ----------
+    root : str, default '~/.mxnet/datasets/wikitext-103'
+        Path to temp folder for storing data.
+    segment : str, default 'train'
+        Dataset segment. Options are 'train', 'validation', 'test'.
+    vocab : :class:`~mxnet.contrib.text.vocab.Vocabulary`, default None
+        The vocabulary to use for indexing the text dataset.
+        If None, a default vocabulary is created.
+    seq_len : int, default 35
+        The sequence length of each sample, regardless of the sentence boundary.
+    """
+    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'wikitext-103'),
+                 segment='train', vocab=None, seq_len=35):
+        self._archive_file = ('wikitext-103-v1.zip', '0aec09a7537b58d4bb65362fee27650eeaba625a')
+        self._data_file = {'train': ('wiki.train.tokens',
+                                     'b7497e2dfe77e72cfef5e3dbc61b7b53712ac211'),
+                           'validation': ('wiki.valid.tokens',
+                                          'c326ac59dc587676d58c422eb8a03e119582f92b'),
+                           'test': ('wiki.test.tokens',
+                                    '8a5befc548865cec54ed4273cf87dbbad60d1e47')}
+        self._segment = segment
+        self._seq_len = seq_len
+        super(WikiText103, self).__init__('wikitext-103', root, vocab)
diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index 4b97e4369cb6..fe1e81326427 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -181,3 +181,39 @@ def __getitem__(self, idx):
 
     def __len__(self):
         return len(self._record.keys)
+
+apache_repo_url = 'https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/'
+
+class _DownloadedDataset(Dataset):
+    """Base class for MNIST, cifar10, etc."""
+    def __init__(self, repo_dir, root, transform):
+        self._root = os.path.expanduser(root)
+        self._repo_dir = repo_dir
+        self._transform = transform
+        self._data = None
+        self._label = None
+
+        repo_url = os.environ.get('MXNET_GLUON_REPO', apache_repo_url)
+        if repo_url[-1] != '/':
+            repo_url = repo_url+'/'
+        self._base_url = repo_url
+
+        if not os.path.isdir(self._root):
+            os.makedirs(self._root)
+        self._get_data()
+
+    def __getitem__(self, idx):
+        if self._transform is not None:
+            return self._transform(self._data[idx], self._label[idx])
+        return self._data[idx], self._label[idx]
+
+    def __len__(self):
+        return len(self._label)
+
+    def _get_data(self):
+        raise NotImplementedError
+
+    def _get_url(self, filename):
+        return '{base_url}gluon/dataset/{repo_dir}/{filename}'.format(base_url=self._base_url,
+                                                                      repo_dir=self._repo_dir,
+                                                                      filename=filename)
diff --git a/python/mxnet/gluon/data/vision/datasets.py b/python/mxnet/gluon/data/vision/datasets.py
index
4ddc2e3ede73..4568da67a95d 100644 --- a/python/mxnet/gluon/data/vision/datasets.py +++ b/python/mxnet/gluon/data/vision/datasets.py @@ -32,45 +32,8 @@ from ...utils import download, check_sha1 from .... import nd, image, recordio -apache_repo_url = 'https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/' -class _DownloadedDataset(dataset.Dataset): - """Base class for MNIST, cifar10, etc.""" - def __init__(self, repo_dir, root, train, transform): - self._root = os.path.expanduser(root) - self._repo_dir = repo_dir - self._train = train - self._transform = transform - self._data = None - self._label = None - - repo_url = os.environ.get('MXNET_GLUON_REPO', apache_repo_url) - if repo_url[-1] != '/': - repo_url = repo_url+'/' - self._base_url = repo_url - - if not os.path.isdir(self._root): - os.makedirs(self._root) - self._get_data() - - def __getitem__(self, idx): - if self._transform is not None: - return self._transform(self._data[idx], self._label[idx]) - return self._data[idx], self._label[idx] - - def __len__(self): - return len(self._label) - - def _get_data(self): - raise NotImplementedError - - def _get_url(self, filename): - return '{base_url}gluon/dataset/{repo_dir}/{filename}'.format(base_url=self._base_url, - repo_dir=self._repo_dir, - filename=filename) - - -class MNIST(_DownloadedDataset): +class MNIST(dataset._DownloadedDataset): """MNIST handwritten digits dataset from http://yann.lecun.com/exdb/mnist Each sample is an image (in 3D NDArray) with shape (28, 28, 1). @@ -90,6 +53,7 @@ class MNIST(_DownloadedDataset): """ def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'mnist'), train=True, transform=None): + self._train = train self._train_data = ('train-images-idx3-ubyte.gz', '6c95f4b05d2bf285e1bfb0e7960c31bd3b3f8a7d') self._train_label = ('train-labels-idx1-ubyte.gz', @@ -98,7 +62,7 @@ def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'mnist'), 'c3a25af1f52dad7f726cce8cacb138654b760d48') self._test_label = ('t10k-labels-idx1-ubyte.gz', '763e7fa3757d93b0cdec073cef058b2004252c17') - super(MNIST, self).__init__('mnist', root, train, transform) + super(MNIST, self).__init__('mnist', root, transform) def _get_data(self): if self._train: @@ -148,6 +112,7 @@ class FashionMNIST(MNIST): """ def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'fashion-mnist'), train=True, transform=None): + self._train = train self._train_data = ('train-images-idx3-ubyte.gz', '0cf37b0d40ed5169c6b3aba31069a9770ac9043d') self._train_label = ('train-labels-idx1-ubyte.gz', @@ -156,10 +121,10 @@ def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'fashion-mnist') '626ed6a7c06dd17c0eec72fa3be1740f146a2863') self._test_label = ('t10k-labels-idx1-ubyte.gz', '17f9ab60e7257a1620f4ad76bbbaf857c3920701') - super(MNIST, self).__init__('fashion-mnist', root, train, transform) # pylint: disable=bad-super-call + super(MNIST, self).__init__('fashion-mnist', root, transform) # pylint: disable=bad-super-call -class CIFAR10(_DownloadedDataset): +class CIFAR10(dataset._DownloadedDataset): """CIFAR10 image classification dataset from https://www.cs.toronto.edu/~kriz/cifar.html Each sample is an image (in 3D NDArray) with shape (32, 32, 1). 
@@ -179,6 +144,7 @@ class CIFAR10(_DownloadedDataset): """ def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'cifar10'), train=True, transform=None): + self._train = train self._archive_file = ('cifar-10-binary.tar.gz', 'fab780a1e191a7eda0f345501ccd62d20f7ed891') self._train_data = [('data_batch_1.bin', 'aadd24acce27caa71bf4b10992e9e7b2d74c2540'), ('data_batch_2.bin', 'c0ba65cce70568cd57b4e03e9ac8d2a5367c1795'), @@ -186,7 +152,7 @@ def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'cifar10'), ('data_batch_4.bin', 'aab85764eb3584312d3c7f65fd2fd016e36a258e'), ('data_batch_5.bin', '26e2849e66a845b7f1e4614ae70f4889ae604628')] self._test_data = [('test_batch.bin', '67eb016db431130d61cd03c7ad570b013799c88c')] - super(CIFAR10, self).__init__('cifar10', root, train, transform) + super(CIFAR10, self).__init__('cifar10', root, transform) def _read_batch(self, filename): with open(filename, 'rb') as fin: @@ -241,11 +207,12 @@ class CIFAR100(CIFAR10): """ def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'cifar100'), fine_label=False, train=True, transform=None): + self._train = train self._archive_file = ('cifar-100-binary.tar.gz', 'a0bb982c76b83111308126cc779a992fa506b90b') self._train_data = [('train.bin', 'e207cd2e05b73b1393c74c7f5e7bea451d63e08e')] self._test_data = [('test.bin', '8fb6623e830365ff53cf14adec797474f5478006')] self._fine_label = fine_label - super(CIFAR10, self).__init__('cifar100', root, train, transform) # pylint: disable=bad-super-call + super(CIFAR10, self).__init__('cifar100', root, transform) # pylint: disable=bad-super-call def _read_batch(self, filename): with open(filename, 'rb') as fin: diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py index 1a188c34b147..2f8558f33840 100644 --- a/tests/python/unittest/test_gluon_contrib.py +++ b/tests/python/unittest/test_gluon_contrib.py @@ -172,6 +172,29 @@ def test_identity(): mx.test_utils.assert_almost_equal(model(x).asnumpy(), x.asnumpy()) +def test_datasets(): + wikitext2_train = contrib.data.text.WikiText2(root='data/wikitext-2', segment='train') + wikitext2_val = contrib.data.text.WikiText2(root='data/wikitext-2', segment='validation', + vocab=wikitext2_train.vocabulary) + wikitext2_test = contrib.data.text.WikiText2(root='data/wikitext-2', segment='test') + assert len(wikitext2_train) == 42780 + assert len(wikitext2_train.vocabulary) == 33278 + assert len(wikitext2_train.frequencies) == 33277 + assert len(wikitext2_val) == 632 + assert len(wikitext2_val.vocabulary) == 33278 + assert len(wikitext2_val.frequencies) == 13776 + assert len(wikitext2_test) == 15941 + assert len(wikitext2_test.vocabulary) == 14143, len(wikitext2_test.vocabulary) + assert len(wikitext2_test.frequencies) == 14142, len(wikitext2_test.frequencies) + assert wikitext2_test.frequencies['English'] == 32 + + +def test_sampler(): + interval_sampler = contrib.data.IntervalSampler(10, 3) + assert sorted(list(interval_sampler)) == list(range(10)) + interval_sampler = contrib.data.IntervalSampler(10, 3, rollover=False) + assert list(interval_sampler) == [0, 3, 6, 9] + if __name__ == '__main__': import nose
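[Editor's note, not part of the patch: beyond the unit tests, a short script along these lines exercises the new API. It is a sketch only; the paths mirror test_datasets() above, the first call downloads wikitext-2, and the lookup token 'the' is an arbitrary choice.]

from mxnet.gluon import contrib

# Build the training split, then reuse its vocabulary for validation,
# mirroring test_datasets() above.
train = contrib.data.text.WikiText2(root='data/wikitext-2', segment='train', seq_len=35)
val = contrib.data.text.WikiText2(root='data/wikitext-2', segment='validation',
                                  vocab=train.vocabulary, seq_len=35)

print(len(train), len(train.vocabulary))   # number of seq_len-long samples, vocabulary size
print(train.frequencies['the'])            # raw token count, as with 'English' in the test
data, label = train[0]
print(data.shape, label.shape)             # each sample is a (35,) vector of token indices

# IntervalSampler visits indices at a fixed stride and, by default,
# rolls over to the skipped offsets until every index is visited.
print(list(contrib.data.IntervalSampler(10, 3)))   # [0, 3, 6, 9, 1, 4, 7, 2, 5, 8]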