Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
reduce a copy for rowsparse parameter.reduce (#12039)
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-haibin-lin authored Aug 10, 2018
1 parent 5a9c3af commit 6f7dee0
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 20 deletions.
2 changes: 1 addition & 1 deletion python/mxnet/gluon/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def _reduce(self):
# fetch all rows for 'row_sparse' param
all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx)
data = ndarray.zeros(self.shape, stype='row_sparse', ctx=ctx)
self._trainer._row_sparse_pull(self, data, all_row_ids)
self._trainer._row_sparse_pull(self, data, all_row_ids, full_idx=True)
return data

def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
Expand Down
11 changes: 9 additions & 2 deletions python/mxnet/gluon/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,14 +235,21 @@ def set_learning_rate(self, lr):
else:
self._optimizer.set_learning_rate(lr)

def _row_sparse_pull(self, parameter, out, row_id):
def _row_sparse_pull(self, parameter, out, row_id, full_idx=False):
"""Internal method to invoke pull operations on KVStore. If `full_idx` is set to True,
`kv.pull` is preferred instead of `kv.row_sparse_pull`.
"""
# initialize kv and params if not already
if not self._kv_initialized:
self._init_kvstore()
if self._params_to_init:
self._init_params()
idx = self._param2idx[parameter.name]
self._kvstore.row_sparse_pull(idx, out=out, row_ids=row_id, priority=-idx)
if full_idx and 'dist' not in self._kvstore.type:
assert row_id.size == out.shape[0]
self._kvstore.pull(idx, out=out, priority=-idx, ignore_sparse=False)
else:
self._kvstore.row_sparse_pull(idx, out=out, row_ids=row_id, priority=-idx)

def step(self, batch_size, ignore_stale_grad=False):
"""Makes one step of parameter update. Should be called after
Expand Down
35 changes: 18 additions & 17 deletions tests/python/unittest/test_gluon_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,24 @@ def test_trainer_save_load():
# check if parameter dict is correctly associated with optimizer after load_state
assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2

@with_seed()
def test_trainer_sparse_save_load():
x = gluon.Parameter('x', shape=(10, 1), lr_mult=1.0, stype='row_sparse')
x.initialize(ctx=[mx.cpu(0)], init='zeros')
trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 0.1})
all_rows = mx.nd.arange(0, 10, ctx=mx.cpu(0))
with mx.autograd.record():
for w in x.list_row_sparse_data(all_rows):
y = w * 1
y.backward()
trainer.step(1)
assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.1
trainer.save_states('test_trainer_sparse_save_load.states')
trainer.load_states('test_trainer_sparse_save_load.states')
x.lr_mult = 2.0
# check if parameter dict is correctly associated with optimizer after load_state
assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2

@with_seed()
def test_trainer_multi_layer_init():
class Net(gluon.Block):
Expand Down Expand Up @@ -158,23 +176,6 @@ def check_init(ctxes):
check_init([mx.cpu(1), mx.cpu(2)])
check_init([mx.cpu(1)])

@with_seed()
def test_trainer_save_load():
x = gluon.Parameter('x', shape=(10,), lr_mult=1.0)
x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 0.1})
with mx.autograd.record():
for w in x.list_data():
y = w + 1
y.backward()
trainer.step(1)
assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.1
trainer.save_states('test_trainer_save_load.states')
trainer.load_states('test_trainer_save_load.states')
x.lr_mult = 2.0
# check if parameter dict is correctly associated with optimizer after load_state
assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2

@with_seed()
def test_trainer_reset_kv():
def check_trainer_reset_kv(kv):
Expand Down

0 comments on commit 6f7dee0

Please sign in to comment.