Skip to content

Commit

Permalink
support str key type in kvstore (apache#6765)
Browse files Browse the repository at this point in the history
* update kvstore unit test

* update model/module.py

* fix lint

* remove int keys in kvstore

* update cast to str function

* remove _cast_to_str_keys

* fix lint

* always cast to str

Conflicts:
	include/mxnet/c_api.h
	include/mxnet/kvstore.h
	python/mxnet/kvstore.py
	python/mxnet/model.py
	python/mxnet/module/module.py
	src/c_api/c_api.cc
	src/kvstore/kvstore_local.h
	tests/python/unittest/test_kvstore.py
  • Loading branch information
eric-haibin-lin committed Jul 12, 2017
1 parent e4e9e40 commit 955b13d
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 44 deletions.
54 changes: 19 additions & 35 deletions python/mxnet/kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,27 @@
from .base import NDArrayHandle, KVStoreHandle
from . import optimizer as opt

def _ctype_str_key_value(keys, vals):
names = []
if isinstance(keys, str):
if isinstance(vals, NDArray):
names.append(c_str(keys))
return (c_array(ctypes.c_char_p, names),
c_array(NDArrayHandle, [vals.handle]))
else:
for value in vals:
assert(isinstance(value, NDArray))
return (c_array(ctypes.c_char_p, [c_str(keys)] * len(vals)),
c_array(NDArrayHandle, [value.handle for value in vals]))
else:
def _ctype_key_value(keys, vals):
if isinstance(keys, (tuple, list)):
assert(len(keys) == len(vals))
for k in keys:
assert(isinstance(k, str))
c_keys = []
c_vals = []
for key, val in zip(keys, vals):
c_key_i, c_val_i = _ctype_str_key_value(key, val)
c_key_i, c_val_i = _ctype_key_value(key, val)
c_keys += c_key_i
c_vals += c_val_i
return (c_array(ctypes.c_char_p, c_keys), c_array(NDArrayHandle, c_vals))

def _cast_to_str_keys(keys):
if isinstance(keys, str):
return keys
if isinstance(keys, int):
return str(keys)
str_keys = []
for key in keys:
str_keys.append(str(key) if isinstance(key, int) else key)
return str_keys
names = []
keys = str(keys)
if isinstance(vals, NDArray):
names.append(c_str(keys))
return (c_array(ctypes.c_char_p, names),
c_array(NDArrayHandle, [vals.handle]))
else:
for value in vals:
assert(isinstance(value, NDArray))
return (c_array(ctypes.c_char_p, [c_str(keys)] * len(vals)),
c_array(NDArrayHandle, [value.handle for value in vals]))

def _updater_wrapper(updater):
"""A wrapper for the user-defined handle."""
Expand Down Expand Up @@ -104,8 +92,7 @@ def init(self, key, value):
>>> keys = ['5', '7', '9']
>>> kv.init(keys, [mx.nd.ones(shape)]*len(keys))
"""
key = _cast_to_str_keys(key)
ckeys, cvals = _ctype_str_key_value(key, value)
ckeys, cvals = _ctype_key_value(key, value)
check_call(_LIB.MXKVStoreInitEx(self.handle, mx_uint(len(ckeys)), ckeys, cvals))

def push(self, key, value, priority=0):
Expand Down Expand Up @@ -165,8 +152,7 @@ def push(self, key, value, priority=0):
[[ 4. 4. 4.]
[ 4. 4. 4.]]
"""
key = _cast_to_str_keys(key)
ckeys, cvals = _ctype_str_key_value(key, value)
ckeys, cvals = _ctype_key_value(key, value)
check_call(_LIB.MXKVStorePushEx(
self.handle, mx_uint(len(ckeys)), ckeys, cvals,
ctypes.c_int(priority)))
Expand Down Expand Up @@ -239,8 +225,7 @@ def pull(self, key, out=None, priority=0):
else:
for v in val:
assert(v.storage_type == 'default')
key = _cast_to_str_keys(key)
ckeys, cvals = _ctype_str_key_value(key, out)
ckeys, cvals = _ctype_key_value(key, out)
check_call(_LIB.MXKVStorePullEx(
self.handle, mx_uint(len(ckeys)), ckeys, cvals,
ctypes.c_int(priority)))
Expand Down Expand Up @@ -306,9 +291,8 @@ def row_sparse_pull(self, key, out=None, priority=0, row_ids=None):
else:
for v in val:
assert(v.storage_type == 'row_sparse')
key = _cast_to_str_keys(key)
ckeys, cvals = _ctype_str_key_value(key, out)
_, crow_ids = _ctype_str_key_value(key, row_ids)
ckeys, cvals = _ctype_key_value(key, out)
_, crow_ids = _ctype_key_value(key, row_ids)
assert(len(crow_ids) == len(cvals)), (len(crow_ids), len(cvals))
#TODO(haibin) pickup upstream changes which removed `_cast_to_str_keys`

Expand Down
1 change: 1 addition & 0 deletions python/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def _initialize_kvstore(kvstore, param_arrays, arg_params, param_names,

def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names,
sparse_pull_dict=None):

"""Perform update of param_arrays from grad_arrays on kvstore."""
for index, pair in enumerate(zip(param_arrays, grad_arrays)):
arg_list, grad_list = pair
Expand Down
15 changes: 6 additions & 9 deletions tests/python/unittest/test_kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def init_kv_with_str(stype='default'):
"""init kv """
kv = mx.kv.create()
# single
kv.init('a', mx.nd.zeros(shape))
kv.init('a', mx.nd.zeros(shape, storage_type=stype))
# list
kv.init(str_keys, [mx.nd.zeros(shape=shape, storage_type=stype)] * len(keys))
return kv
Expand Down Expand Up @@ -80,7 +80,6 @@ def check_init(kv, key):
check_init(mx.kv.create(), 3)
check_init(mx.kv.create(), 'a')


def test_list_kv_pair():
"""list key-value pair push & pull"""
def check_list_kv_pair(kv, key):
Expand Down Expand Up @@ -128,7 +127,7 @@ def test_sparse_aggregator():
"""aggregate sparse ndarray on muliple devices"""

stype = 'row_sparse'
kv = init_kv(stype)
kv = init_kv_with_str(stype)

# devices
num_devs = 4
Expand All @@ -142,8 +141,8 @@ def test_sparse_aggregator():

# prepare row_ids
all_rows = mx.nd.array(np.arange(shape[0]), dtype='int64')
kv.push(3, vals)
kv.row_sparse_pull(3, out=vals, row_ids=[all_rows] * len(vals))
kv.push('a', vals)
kv.row_sparse_pull('a', out=vals, row_ids=[all_rows] * len(vals))
result_sum = np.zeros(shape)
for v in vals:
result_sum += v.asnumpy()
Expand All @@ -155,15 +154,14 @@ def test_sparse_aggregator():
for v in vals[0]:
expected_sum += v.asnumpy()

kv.push(keys, vals)
kv.row_sparse_pull(keys, out=vals, row_ids=[[all_rows] * num_devs] * len(vals))
kv.push(str_keys, vals)
kv.row_sparse_pull(str_keys, out=vals, row_ids=[[all_rows] * num_devs] * len(vals))
for vv in vals:
result_sum = np.zeros(shape)
for v in vv:
result_sum += v.asnumpy()
assert_almost_equal(result_sum, expected_sum * num_devs)


def updater(key, recv, local):
"""use updater: +="""
local += recv
Expand Down Expand Up @@ -208,7 +206,6 @@ def check_updater(kv, key, key_list):
check_updater(str_kv, 'a', str_keys)



def test_get_type():
kvtype = 'local_allreduce_cpu'
kv = mx.kv.create(kvtype)
Expand Down

0 comments on commit 955b13d

Please sign in to comment.