From 60cac0ba721557f45a96770b9c13212800924b24 Mon Sep 17 00:00:00 2001
From: eric-haibin-lin
Date: Wed, 12 Jul 2017 23:48:11 +0000
Subject: [PATCH] update module API for other submodules

* update stypes in kvstore after refactoring
* change type of size from size_t to int64_t
* add sparse linear regression example
* remove sparse_pull_dict from module
* fix init_optimizer for seq_module
* update sparse example
---
 example/sparse/get_data.py                   |  13 ++
 example/sparse/linear_regression.py          | 178 +++++++++++++++++++
 python/mxnet/kvstore.py                      |  15 +-
 python/mxnet/model.py                        |  42 +++--
 python/mxnet/module/base_module.py           |   3 +-
 python/mxnet/module/module.py                |  18 +-
 src/kvstore/kvstore_dist.h                   |  10 +-
 tests/nightly/dist_sync_kvstore.py           |   4 +-
 tests/python/unittest/test_kvstore.py        |   4 +-
 tests/python/unittest/test_sparse_ndarray.py |  18 +-
 10 files changed, 248 insertions(+), 57 deletions(-)
 create mode 100644 example/sparse/get_data.py
 create mode 100644 example/sparse/linear_regression.py

diff --git a/example/sparse/get_data.py b/example/sparse/get_data.py
new file mode 100644
index 000000000000..6b6723f07be4
--- /dev/null
+++ b/example/sparse/get_data.py
@@ -0,0 +1,13 @@
+# pylint: skip-file
+import os
+
+def get_libsvm_data(data_dir, data_name, url, data_origin_name):
+    if not os.path.isdir(data_dir):
+        os.makedirs(data_dir)
+    os.chdir(data_dir)
+    if not os.path.exists(data_name):
+        import urllib
+        zippath = os.path.join(data_dir, data_origin_name)
+        urllib.urlretrieve(url, zippath)
+        os.system("bzip2 -d %s" % data_origin_name)
+    os.chdir("..")
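
[Reviewer note — not part of the patch] get_libsvm_data relies on the Python 2-only
urllib.urlretrieve API, so this helper will fail under Python 3. A minimal sketch of
a 2/3-compatible import, assuming only the standard library:

    try:
        from urllib.request import urlretrieve  # Python 3
    except ImportError:
        from urllib import urlretrieve           # Python 2

    # then call urlretrieve(url, zippath) instead of urllib.urlretrieve(url, zippath)
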
"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2", + 'feature_dim': 1000000, +} + +kdda = { + 'data_name': 'kdda.t', + 'data_origin_name': 'kdda.t.bz2', + 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2", + 'feature_dim': 20216830, +} + +datasets = { 'kdda' : kdda, 'avazu' : avazu } + +def regression_model(feature_dim): + initializer = mx.initializer.Normal() + x = mx.symbol.Variable("data", stype='csr') + norm_init = mx.initializer.Normal(sigma=0.01) + v = mx.symbol.Variable("v", shape=(feature_dim, 1), init=norm_init, stype='row_sparse') + embed = mx.symbol.dot(x, v) + y = mx.symbol.Variable("softmax_label") + model = mx.symbol.LinearRegressionOutput(data=embed, label=y, name="out") + return model + +if __name__ == '__main__': + + # arg parser + args = parser.parse_args() + num_epoch = args.num_epoch + num_batch = args.num_batch + kvstore = args.kvstore + profiler = args.profiler > 0 + batch_size = args.batch_size + dummy_iter = args.dummy_iter + dataset = args.dataset + log_level = args.log_level + + # create kvstore + kv = mx.kvstore.create(kvstore) + rank = kv.rank + num_worker = kv.num_workers + + # only print log for rank 0 worker + import logging + if rank != 0: + log_level = logging.ERROR + elif log_level == 'DEBUG': + log_level = logging.DEBUG + else: + log_level = logging.INFO + head = '%(asctime)-15s %(message)s' + logging.basicConfig(level=log_level, format=head) + + # dataset + assert(dataset in datasets), "unknown dataset " + dataset + metadata = datasets[dataset] + feature_dim = metadata['feature_dim'] + if logging: + logging.debug('preparing data ... ') + data_dir = os.path.join(os.getcwd(), 'data') + path = os.path.join(data_dir, metadata['data_name']) + if not os.path.exists(path): + get_libsvm_data(data_dir, metadata['data_name'], metadata['url'], + metadata['data_origin_name']) + assert os.path.exists(path) + + # data iterator + train_data = mx.io.LibSVMIter(data_libsvm=path, data_shape=(feature_dim,), + batch_size=batch_size, num_parts=num_worker, + part_index=rank) + if dummy_iter: + train_data = DummyIter(train_data) + + # model + model = regression_model(feature_dim) + + # module + mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['softmax_label']) + mod.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label) + mod.init_params(initializer=mx.init.Uniform(scale=.1)) + sgd = mx.optimizer.SGD(momentum=0.0, clip_gradient=5.0, + learning_rate=0.1, rescale_grad=1.0/batch_size/num_worker) + mod.init_optimizer(optimizer=sgd, kvstore=kv) + # use accuracy as the metric + metric = mx.metric.create('MSE') + + # start profiler + if profiler: + import random + name = 'profile_output_' + str(num_worker) + '.json' + mx.profiler.profiler_set_config(mode='all', filename=name) + mx.profiler.profiler_set_state('run') + + logging.debug('start training ...') + start = time.time() + data_iter = iter(train_data) + for epoch in range(num_epoch): + nbatch = 0 + end_of_batch = False + data_iter.reset() + metric.reset() + next_batch = next(data_iter) + while not end_of_batch: + nbatch += 1 + batch = next_batch + # TODO(haibin) remove extra copy after Jun's change + row_ids = batch.data[0].indices.copyto(mx.cpu()) + # pull sparse weight + index = mod._exec_group.param_names.index('v') + kv.row_sparse_pull('v', mod._exec_group.param_arrays[index], + priority=-index, row_ids=[row_ids]) + mod.forward_backward(batch) + # update parameters + mod.update() + try: + # pre fetch next batch + 
+                next_batch = next(data_iter)
+                if nbatch == num_batch:
+                    raise StopIteration
+            except StopIteration:
+                end_of_batch = True
+            # accumulate the prediction error (MSE)
+            mod.update_metric(metric, batch.label)
+        logging.info('epoch %d, %s' % (epoch, metric.get()))
+    if profiler:
+        mx.profiler.profiler_set_state('stop')
+    end = time.time()
+    time_cost = end - start
+    logging.info('num_worker = ' + str(num_worker) + ', time cost = ' + str(time_cost))
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
index 1fced8dd452d..8d96c751ccb3 100644
--- a/python/mxnet/kvstore.py
+++ b/python/mxnet/kvstore.py
@@ -5,7 +5,7 @@
 import ctypes
 import pickle
 from .ndarray import NDArray
-from .sparse_ndarray import _ndarray_cls
+from .ndarray import _ndarray_cls
 from .base import _LIB
 from .base import check_call, c_array, c_str, string_types, mx_uint, py_str
 from .base import NDArrayHandle, KVStoreHandle
@@ -221,10 +221,10 @@ def pull(self, key, out=None, priority=0):
             out = [out]
         for val in out:
             if not isinstance(val, (list, tuple)):
-                assert(val.storage_type == 'default')
+                assert(val.stype == 'default')
             else:
                 for v in val:
-                    assert(v.storage_type == 'default')
+                    assert(v.stype == 'default')
         ckeys, cvals = _ctype_key_value(key, out)
         check_call(_LIB.MXKVStorePullEx(
             self.handle, mx_uint(len(ckeys)), ckeys, cvals,
@@ -245,7 +245,7 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
             Keys.
 
         out: NDArray or list of NDArray or list of list of NDArray
-            Values corresponding to the keys. The storage_type is expected to be row_sparse
+            Values corresponding to the keys. The stype is expected to be row_sparse
 
         priority : int, optional
             The priority of the pull operation.
@@ -287,14 +287,13 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
             out = [out]
         for val in out:
             if not isinstance(val, (list, tuple)):
-                assert(val.storage_type == 'row_sparse')
+                assert(val.stype == 'row_sparse')
             else:
                 for v in val:
-                    assert(v.storage_type == 'row_sparse')
+                    assert(v.stype == 'row_sparse')
         ckeys, cvals = _ctype_key_value(key, out)
         _, crow_ids = _ctype_key_value(key, row_ids)
-        assert(len(crow_ids) == len(cvals)), (len(crow_ids), len(cvals))
-        #TODO(haibin) pickup upstream changes which removed `_cast_to_str_keys`
+        assert(len(crow_ids) == len(cvals)), "number of row_ids doesn't match number of values"
         check_call(_LIB.MXKVStorePullRowSparse(
             self.handle, mx_uint(len(ckeys)), ckeys, cvals, crow_ids, ctypes.c_int(priority)))
diff --git a/python/mxnet/model.py b/python/mxnet/model.py
index ae0200fba732..e30f9f332c8c 100644
--- a/python/mxnet/model.py
+++ b/python/mxnet/model.py
@@ -76,23 +76,31 @@ def _create_kvstore(kvstore, num_device, arg_params):
     return (kv, update_on_kvstore)
 
-def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names,
-                        update_on_kvstore, sparse_pull_dict=None):
+def _contains_non_default_storage(params):
+    if isinstance(params, (list, tuple)):
+        for param in params:
+            if param.stype != 'default':
+                return True
+    elif isinstance(params, NDArray):
+        return params.stype != 'default'
+    else:
+        return False
+
+def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, update_on_kvstore):
     """Initialize kvstore"""
     for idx, param_on_devs in enumerate(param_arrays):
         name = param_names[idx]
         kvstore.init(name, arg_params[name])
 
         if update_on_kvstore:
-            if sparse_pull_dict is not None and name in sparse_pull_dict:
-                kvstore.row_sparse_pull(name, param_on_devs, priority=-idx,
-                                        row_ids=sparse_pull_dict[name])
+            if _contains_non_default_storage(param_on_devs):
+                # skip pulling row_sparse weights
+                warnings.warn('Detected a weight with non-default storage in kvstore; skipping ' \
+                              'the pull. Pull it explicitly with row_ids instead', RuntimeWarning)
             else:
                 kvstore.pull(name, param_on_devs, priority=-idx)
 
-def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names,
-                              sparse_pull_dict=None):
-
+def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names):
     """Perform update of param_arrays from grad_arrays on kvstore."""
     for index, pair in enumerate(zip(param_arrays, grad_arrays)):
         arg_list, grad_list = pair
@@ -102,14 +110,15 @@ def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names,
         # push gradient, priority is negative index
         kvstore.push(name, grad_list, priority=-index)
         # pull back the weights
-        if sparse_pull_dict is not None and name in sparse_pull_dict:
-            kvstore.row_sparse_pull(name, arg_list, priority=-index,
-                                    row_ids=sparse_pull_dict[name])
+        if _contains_non_default_storage(arg_list):
+            # skip pulling row_sparse weights
+            warnings.warn('Detected a weight with non-default storage in kvstore; skipping ' \
+                          'the pull. Pull it explicitly with row_ids instead', RuntimeWarning)
         else:
             kvstore.pull(name, arg_list, priority=-index)
 
 def _update_params(param_arrays, grad_arrays, updater, num_device,
-                   kvstore=None, param_names=None, sparse_pull_dict=None):
+                   kvstore=None, param_names=None):
     """Perform update of param_arrays from grad_arrays not on kvstore."""
     for i, pair in enumerate(zip(param_arrays, grad_arrays)):
         arg_list, grad_list = pair
@@ -120,11 +129,12 @@ def _update_params(param_arrays, grad_arrays, updater, num_device,
             name = param_names[index]
             # push gradient, priority is negative index
             kvstore.push(name, grad_list, priority=-index)
-            if sparse_pull_dict is not None and name in sparse_pull_dict:
-                kvstore.row_sparse_pull(name, grad_list, priority=-index,
-                                        row_ids=sparse_pull_dict[name])
+            # pull back the sum gradients, to the same locations.
+            if _contains_non_default_storage(grad_list):
+                # skip pulling row_sparse gradients
+                warnings.warn('Detected a gradient with non-default storage in kvstore; skipping ' \
+                              'the pull. Pull it explicitly with row_ids instead', RuntimeWarning)
             else:
-                # pull back the sum gradients, to the same locations.
                 kvstore.pull(name, grad_list, priority=-index)
         for k, p in enumerate(zip(arg_list, grad_list)):
             # faked an index here, to make optimizer create diff
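
[Reviewer note — not part of the patch] With sparse_pull_dict removed, callers are
now responsible for pulling row_sparse values themselves. A minimal sketch of the
explicit pull path, pieced together from the kvstore tests further down; the
'local' kvstore type, the key name 'w', and initializing a key from a row_sparse
ndarray are assumptions on my part:

    import mxnet as mx

    kv = mx.kv.create('local')
    shape = (10, 4)
    # init with a row_sparse ndarray (note the stype keyword renamed in this patch)
    kv.init('w', mx.nd.zeros(shape, stype='row_sparse'))
    # pull only the rows needed for the current batch
    row_ids = mx.nd.array([0, 3, 7], dtype='int64')
    val = mx.nd.zeros(shape, stype='row_sparse')
    kv.row_sparse_pull('w', out=val, row_ids=row_ids)
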
diff --git a/python/mxnet/module/base_module.py b/python/mxnet/module/base_module.py
index 935b2a9b09cf..05076cec46b7 100644
--- a/python/mxnet/module/base_module.py
+++ b/python/mxnet/module/base_module.py
@@ -932,8 +932,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
         raise NotImplementedError()
 
     def init_optimizer(self, kvstore='local', optimizer='sgd',
-                       optimizer_params=(('learning_rate', 0.01),), force_init=False,
-                       sparse_pull_dict=None):
+                       optimizer_params=(('learning_rate', 0.01),), force_init=False):
         """Installs and initializes optimizers, as well as initialize kvstore for
         distributed training
 
diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py
index 77a9ab91ee90..1594665bf5ef 100644
--- a/python/mxnet/module/module.py
+++ b/python/mxnet/module/module.py
@@ -429,8 +429,7 @@ def reshape(self, data_shapes, label_shapes=None):
         self._exec_group.reshape(self._data_shapes, self._label_shapes)
 
     def init_optimizer(self, kvstore='local', optimizer='sgd',
-                       optimizer_params=(('learning_rate', 0.01),), force_init=False,
-                       sparse_pull_dict=None):
+                       optimizer_params=(('learning_rate', 0.01),), force_init=False):
         """Installs and initializes optimizers.
 
         Parameters
@@ -445,10 +444,6 @@ def init_optimizer(self, kvstore='local', optimizer='sgd',
         force_init : bool
             Default ``False``, indicating whether we should force re-initializing the
             optimizer in the case an optimizer is already installed.
-        sparse_pull_dict : dict of str -> list of NDArray
-            Default to `None`, used for distributed training with sparse parameters.
-            When the name of a row_sparse parameter is in the dict, the initial value pulled
-            to devices will only contain the rows specified by the list of row_id NDArrays.
         """
         assert self.binded and self.params_initialized
 
@@ -502,8 +497,7 @@ def init_optimizer(self, kvstore='local', optimizer='sgd',
                                param_arrays=self._exec_group.param_arrays,
                                arg_params=self._arg_params,
                                param_names=self._param_names,
-                               update_on_kvstore=update_on_kvstore,
-                               sparse_pull_dict=sparse_pull_dict)
+                               update_on_kvstore=update_on_kvstore)
             if update_on_kvstore:
                 kvstore.set_optimizer(self._optimizer)
             else:
@@ -564,7 +558,7 @@ def backward(self, out_grads=None):
         assert self.binded and self.params_initialized
         self._exec_group.backward(out_grads=out_grads)
 
-    def update(self, sparse_pull_dict=None):
+    def update(self):
         """Updates parameters according to the installed optimizer
         and the gradients computed in the previous forward-backward batch.
 
@@ -578,16 +572,14 @@
         if self._update_on_kvstore:
             _update_params_on_kvstore(self._exec_group.param_arrays,
                                       self._exec_group.grad_arrays,
-                                      self._kvstore, self._exec_group.param_names,
-                                      sparse_pull_dict=sparse_pull_dict)
+                                      self._kvstore, self._exec_group.param_names)
         else:
             _update_params(self._exec_group.param_arrays,
                            self._exec_group.grad_arrays,
                            updater=self._updater,
                            num_device=len(self._context),
                            kvstore=self._kvstore,
-                           param_names=self._exec_group.param_names,
-                           sparse_pull_dict=sparse_pull_dict)
+                           param_names=self._exec_group.param_names)
 
     def get_outputs(self, merge_multi_context=True):
         """Gets outputs of the previous forward computation.
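
[Reviewer note — not part of the patch] Since Module.update() and init_optimizer()
no longer accept sparse_pull_dict, the per-batch pull moves into user code. For
reference, this is the pattern from example/sparse/linear_regression.py above
(mod, kv, and batch as defined there; _exec_group is a private attribute, so this
reads as a stopgap until a public API lands):

    # pull the rows of 'v' needed by this batch before computing
    row_ids = batch.data[0].indices.copyto(mx.cpu())
    index = mod._exec_group.param_names.index('v')
    kv.row_sparse_pull('v', mod._exec_group.param_arrays[index],
                       priority=-index, row_ids=[row_ids])
    mod.forward_backward(batch)
    mod.update()
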
diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h
index 14a324514cff..9a3bc31f4ac4 100644
--- a/src/kvstore/kvstore_dist.h
+++ b/src/kvstore/kvstore_dist.h
@@ -316,7 +316,7 @@ class KVStoreDist : public KVStoreLocal {
       auto indices_data = indices.data();
       const auto offsets = indices_data.dptr<int64_t>();
       const auto unit_len = recv_buf->shape().ProdShape(1, recv_buf->shape().ndim());
-      size_t size = num_rows * unit_len;
+      const int64_t size = num_rows * unit_len;
       // convert to ps keys in row sparse format
       PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets,
                                       unit_len, recv_buf->shape()[0]);
@@ -351,10 +351,10 @@ class KVStoreDist : public KVStoreLocal {
 #endif
       real_t* data = static_cast<real_t*>(send_buf.data().dptr_);
       bool init = send_buf.storage_initialized();
-      size_t num_rows = init ? send_buf.aux_shape(kIdx).Size() : 0;
+      const int64_t num_rows = init ? send_buf.aux_shape(kIdx)[0] : 0;
      const auto offsets = init ? send_buf.aux_data(kIdx).dptr<int64_t>() : nullptr;
       const auto unit_len = send_buf.shape().ProdShape(1, send_buf.shape().ndim());
-      const auto size = num_rows * unit_len;
+      const int64_t size = num_rows * unit_len;
       // convert to ps keys in row sparse format
       PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets,
                                       unit_len, send_buf.shape()[0]);
@@ -451,7 +451,7 @@ class KVStoreDist : public KVStoreLocal {
   }
 
   // TODO(haibin) this encoding method for row sparse keys doesn't allow cross-layer batching
-  inline PSKV& EncodeRowSparseKey(const int key, const size_t size, const int64_t num_rows,
+  inline PSKV& EncodeRowSparseKey(const int key, const int64_t size, const int64_t num_rows,
                                   const int64_t *offsets, const size_t unit_len,
                                   const int64_t total_num_rows) {
     using namespace common;
@@ -481,7 +481,7 @@ class KVStoreDist : public KVStoreLocal {
         pskv.keys.push_back(master_key);
         pskv.lens.push_back(0);
         for (auto offset = lb; offset < ub; offset++) {
-          ps::Key ps_key = krs[i].begin() + key + *offset - start_row;
+          ps::Key ps_key = krs[i].begin() + key + (*offset - start_row);
           CHECK_LT(ps_key, krs[i].end());
           pskv.keys.push_back(ps_key);
           pskv.lens.push_back(unit_len);
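
[Reviewer note — not part of the patch] A schematic model of the per-row key
arithmetic in EncodeRowSparseKey as I read the hunk above: each requested row
becomes its own ps-lite key offset into the server's key range, with payload
length unit_len (elements per row). The function name, parameters, and worked
numbers below are illustrative only, and the master-key handling is omitted
since its value is defined outside this hunk:

    def per_row_keys(key, row_ids, range_begin, start_row, unit_len):
        # mirrors: ps_key = krs[i].begin() + key + (*offset - start_row)
        keys, lens = [], []
        for row in sorted(row_ids):
            keys.append(range_begin + key + (row - start_row))
            lens.append(unit_len)
        return keys, lens

    # rows [2, 5] of the parameter with key 3, on a server whose key range
    # begins at 1000 and whose row partition begins at row 0:
    print(per_row_keys(3, [2, 5], 1000, 0, unit_len=4))
    # ([1005, 1008], [4, 4])
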
diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py
index 53186a2ab804..f88b412b027c 100644
--- a/tests/nightly/dist_sync_kvstore.py
+++ b/tests/nightly/dist_sync_kvstore.py
@@ -66,7 +66,7 @@ def check_row_sparse_keys(kv, my_rank, nworker):
         row_ids_np = np.random.randint(num_rows, size=num_rows)
         row_ids = mx.nd.array(row_ids_np, dtype='int64')
         # perform pull
-        val = mx.nd.zeros(shape, storage_type='row_sparse')
+        val = mx.nd.zeros(shape, stype='row_sparse')
         kv.row_sparse_pull('9', out=val, row_ids=row_ids)
         # prepare updated values
         updated_val = mx.nd.ones(shape)
@@ -134,7 +134,7 @@ def check_big_row_sparse_keys(kv, my_rank, nworker):
         row_ids_np = np.random.randint(num_rows, size=num_rows)
         row_ids = mx.nd.array(row_ids_np, dtype='int64')
         # perform pull
-        val = mx.nd.zeros(big_shape, storage_type='row_sparse')
+        val = mx.nd.zeros(big_shape, stype='row_sparse')
         kv.row_sparse_pull('100', out=val, row_ids=row_ids)
         # prepare expected result
         updated_val = mx.nd.ones(big_shape)
diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py
index a0dd870d3106..665467854977 100644
--- a/tests/python/unittest/test_kvstore.py
+++ b/tests/python/unittest/test_kvstore.py
@@ -20,9 +20,9 @@ def init_kv_with_str(stype='default'):
     """init kv """
     kv = mx.kv.create()
     # single
-    kv.init('a', mx.nd.zeros(shape, storage_type=stype))
+    kv.init('a', mx.nd.zeros(shape, stype=stype))
     # list
-    kv.init(str_keys, [mx.nd.zeros(shape=shape, storage_type=stype)] * len(keys))
+    kv.init(str_keys, [mx.nd.zeros(shape=shape, stype=stype)] * len(keys))
     return kv
 
 def check_diff_to_scalar(A, x):
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 097f650db73d..76e121e462dc 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -21,10 +21,10 @@ def sparse_nd_ones(shape, stype):
 def check_sparse_nd_elemwise_binary(shapes, stypes, f, g):
     # generate inputs
     nds = []
-    for i, storage_type in enumerate(stypes):
-        if storage_type == 'row_sparse':
-            nd, _ = rand_sparse_ndarray(shapes[i], storage_type)
-        elif storage_type == 'default':
+    for i, stype in enumerate(stypes):
+        if stype == 'row_sparse':
+            nd, _ = rand_sparse_ndarray(shapes[i], stype)
+        elif stype == 'default':
             nd = mx.nd.array(random_arrays(shapes[i]), dtype = np.float32)
         else:
             assert(False)
@@ -78,7 +78,7 @@ def check_sparse_nd_copy(from_stype, to_stype, shape):
     check_sparse_nd_copy('default', 'row_sparse', shape)
     check_sparse_nd_copy('default', 'csr', shape)
     check_sparse_nd_copy('row_sparse', 'row_sparse', shape_3d)
-    check_sparse_nd_copy('default', 'row_sparse', shape_3d)
+
 
 def test_sparse_nd_basic():
     def check_rsp_creation(values, indices, shape):
@@ -140,8 +140,8 @@ def check_sparse_nd_setitem(stype, shape, dst):
 
 def test_sparse_nd_slice():
     def check_sparse_nd_csr_slice(shape):
-        storage_type = 'csr'
-        A, _ = rand_sparse_ndarray(shape, storage_type)
+        stype = 'csr'
+        A, _ = rand_sparse_ndarray(shape, stype)
         A2 = A.asnumpy()
         start = rnd.randint(0, shape[0] - 1)
         end = rnd.randint(start + 1, shape[0])
@@ -307,8 +307,8 @@ def check_binary(fn, stype):
                 rshape = list(oshape)
                 lhs = np.random.uniform(0, 1, size=lshape)
                 rhs = np.random.uniform(0, 1, size=rshape)
-                lhs_nd = mx.nd.cast_storage(mx.nd.array(lhs), storage_type=stype)
-                rhs_nd = mx.nd.cast_storage(mx.nd.array(rhs), storage_type=stype)
+                lhs_nd = mx.nd.cast_storage(mx.nd.array(lhs), stype=stype)
+                rhs_nd = mx.nd.cast_storage(mx.nd.array(rhs), stype=stype)
                 assert_allclose(fn(lhs, rhs), fn(lhs_nd, rhs_nd).asnumpy(),
                                 rtol=1e-4, atol=1e-4)
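
[Reviewer note — not part of the patch] A one-screen smoke test of the
storage_type -> stype rename, distilled from the unit tests above; it assumes
this branch's mx.nd.zeros and mx.nd.cast_storage signatures:

    import numpy as np
    import mxnet as mx

    dense = mx.nd.array(np.random.uniform(size=(4, 3)))
    csr = mx.nd.cast_storage(dense, stype='csr')   # keyword renamed from storage_type
    assert csr.stype == 'csr'                      # property renamed likewise
    rsp = mx.nd.zeros((4, 3), stype='row_sparse')
    assert rsp.stype == 'row_sparse'
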