diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 314ec1f3d08f..1fced8dd452d 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -11,39 +11,27 @@ from .base import NDArrayHandle, KVStoreHandle from . import optimizer as opt -def _ctype_str_key_value(keys, vals): - names = [] - if isinstance(keys, str): - if isinstance(vals, NDArray): - names.append(c_str(keys)) - return (c_array(ctypes.c_char_p, names), - c_array(NDArrayHandle, [vals.handle])) - else: - for value in vals: - assert(isinstance(value, NDArray)) - return (c_array(ctypes.c_char_p, [c_str(keys)] * len(vals)), - c_array(NDArrayHandle, [value.handle for value in vals])) - else: +def _ctype_key_value(keys, vals): + if isinstance(keys, (tuple, list)): assert(len(keys) == len(vals)) - for k in keys: - assert(isinstance(k, str)) c_keys = [] c_vals = [] for key, val in zip(keys, vals): - c_key_i, c_val_i = _ctype_str_key_value(key, val) + c_key_i, c_val_i = _ctype_key_value(key, val) c_keys += c_key_i c_vals += c_val_i return (c_array(ctypes.c_char_p, c_keys), c_array(NDArrayHandle, c_vals)) - -def _cast_to_str_keys(keys): - if isinstance(keys, str): - return keys - if isinstance(keys, int): - return str(keys) - str_keys = [] - for key in keys: - str_keys.append(str(key) if isinstance(key, int) else key) - return str_keys + names = [] + keys = str(keys) + if isinstance(vals, NDArray): + names.append(c_str(keys)) + return (c_array(ctypes.c_char_p, names), + c_array(NDArrayHandle, [vals.handle])) + else: + for value in vals: + assert(isinstance(value, NDArray)) + return (c_array(ctypes.c_char_p, [c_str(keys)] * len(vals)), + c_array(NDArrayHandle, [value.handle for value in vals])) def _updater_wrapper(updater): """A wrapper for the user-defined handle.""" @@ -104,8 +92,7 @@ def init(self, key, value): >>> keys = ['5', '7', '9'] >>> kv.init(keys, [mx.nd.ones(shape)]*len(keys)) """ - key = _cast_to_str_keys(key) - ckeys, cvals = _ctype_str_key_value(key, value) + ckeys, cvals = _ctype_key_value(key, value) check_call(_LIB.MXKVStoreInitEx(self.handle, mx_uint(len(ckeys)), ckeys, cvals)) def push(self, key, value, priority=0): @@ -165,8 +152,7 @@ def push(self, key, value, priority=0): [[ 4. 4. 4.] [ 4. 4. 4.]] """ - key = _cast_to_str_keys(key) - ckeys, cvals = _ctype_str_key_value(key, value) + ckeys, cvals = _ctype_key_value(key, value) check_call(_LIB.MXKVStorePushEx( self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority))) @@ -239,8 +225,7 @@ def pull(self, key, out=None, priority=0): else: for v in val: assert(v.storage_type == 'default') - key = _cast_to_str_keys(key) - ckeys, cvals = _ctype_str_key_value(key, out) + ckeys, cvals = _ctype_key_value(key, out) check_call(_LIB.MXKVStorePullEx( self.handle, mx_uint(len(ckeys)), ckeys, cvals, ctypes.c_int(priority))) @@ -306,9 +291,8 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None): else: for v in val: assert(v.storage_type == 'row_sparse') - key = _cast_to_str_keys(key) - ckeys, cvals = _ctype_str_key_value(key, out) - _, crow_ids = _ctype_str_key_value(key, row_ids) + ckeys, cvals = _ctype_key_value(key, out) + _, crow_ids = _ctype_key_value(key, row_ids) assert(len(crow_ids) == len(cvals)), (len(crow_ids), len(cvals)) #TODO(haibin) pickup upstream changes which removed `_cast_to_str_keys` diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 0ecd0224d1da..ae0200fba732 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -92,6 +92,7 @@ def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names, def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names, sparse_pull_dict=None): + """Perform update of param_arrays from grad_arrays on kvstore.""" for index, pair in enumerate(zip(param_arrays, grad_arrays)): arg_list, grad_list = pair diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py index 17a227a7604f..a0dd870d3106 100644 --- a/tests/python/unittest/test_kvstore.py +++ b/tests/python/unittest/test_kvstore.py @@ -20,7 +20,7 @@ def init_kv_with_str(stype='default'): """init kv """ kv = mx.kv.create() # single - kv.init('a', mx.nd.zeros(shape)) + kv.init('a', mx.nd.zeros(shape, storage_type=stype)) # list kv.init(str_keys, [mx.nd.zeros(shape=shape, storage_type=stype)] * len(keys)) return kv @@ -80,7 +80,6 @@ def check_init(kv, key): check_init(mx.kv.create(), 3) check_init(mx.kv.create(), 'a') - def test_list_kv_pair(): """list key-value pair push & pull""" def check_list_kv_pair(kv, key): @@ -128,7 +127,7 @@ def test_sparse_aggregator(): """aggregate sparse ndarray on muliple devices""" stype = 'row_sparse' - kv = init_kv(stype) + kv = init_kv_with_str(stype) # devices num_devs = 4 @@ -142,8 +141,8 @@ def test_sparse_aggregator(): # prepare row_ids all_rows = mx.nd.array(np.arange(shape[0]), dtype='int64') - kv.push(3, vals) - kv.row_sparse_pull(3, out=vals, row_ids=[all_rows] * len(vals)) + kv.push('a', vals) + kv.row_sparse_pull('a', out=vals, row_ids=[all_rows] * len(vals)) result_sum = np.zeros(shape) for v in vals: result_sum += v.asnumpy() @@ -155,15 +154,14 @@ def test_sparse_aggregator(): for v in vals[0]: expected_sum += v.asnumpy() - kv.push(keys, vals) - kv.row_sparse_pull(keys, out=vals, row_ids=[[all_rows] * num_devs] * len(vals)) + kv.push(str_keys, vals) + kv.row_sparse_pull(str_keys, out=vals, row_ids=[[all_rows] * num_devs] * len(vals)) for vv in vals: result_sum = np.zeros(shape) for v in vv: result_sum += v.asnumpy() assert_almost_equal(result_sum, expected_sum * num_devs) - def updater(key, recv, local): """use updater: +=""" local += recv @@ -208,7 +206,6 @@ def check_updater(kv, key, key_list): check_updater(str_kv, 'a', str_keys) - def test_get_type(): kvtype = 'local_allreduce_cpu' kv = mx.kv.create(kvtype)