
Improving documentation and error messages for Async distributed training with Gluon #11910

Merged · 8 commits · Jul 30, 2018
21 changes: 19 additions & 2 deletions docs/faq/distributed_training.md
@@ -73,6 +73,23 @@ These can be passed as arguments to the iterator.
You can look at [example/gluon/image_classification.py](https://github.com/apache/incubator-mxnet/blob/master/example/gluon/image_classification.py)
to see an example usage.

### Updating weights
The KVStore server supports two modes: one in which it aggregates the gradients and uses them to update the weights, and another in which it only aggregates the gradients. In the latter case, when a worker process pulls from the kvstore, it gets the aggregated gradients. The worker then uses these gradients to update the weights locally.
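
The second mode can be illustrated with the raw kvstore API. This is only a sketch, using a `local` store and a placeholder key name `'w'`:

```python
import mxnet as mx

kv = mx.kv.create('local')   # a distributed run would use a 'dist_*' type
shape = (2, 3)
weight = mx.nd.uniform(shape=shape)
kv.init('w', mx.nd.zeros(shape))

# No optimizer has been registered with the store, so it only aggregates
# what is pushed; the worker pulls the result and applies the update itself.
kv.push('w', mx.nd.ones(shape))      # push a "gradient"
agg = mx.nd.zeros(shape)
kv.pull('w', out=agg)                # agg now holds the aggregated gradient
weight = weight - 0.1 * agg          # worker-side SGD step
```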

When using Gluon, you can choose between these modes by passing the `update_on_kvstore` argument when you create the [Trainer](https://mxnet.incubator.apache.org/versions/master/api/python/gluon/gluon.html#mxnet.gluon.Trainer) object, like this:

```python
trainer = gluon.Trainer(net.collect_params(), optimizer='sgd',
                        optimizer_params={'learning_rate': opt.lr,
                                          'wd': opt.wd,
                                          'momentum': opt.momentum,
                                          'multi_precision': True},
                        kvstore=kv,
                        update_on_kvstore=True)
```

When using the symbolic interface, the weight updates are performed on the server without the user having to do anything special.
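
For example, a minimal sketch with the module API (here `sym` and `train_iter` stand in for a network symbol and a data iterator defined elsewhere):

```python
import mxnet as mx

kv = mx.kv.create('dist_sync')
mod = mx.mod.Module(symbol=sym, context=mx.cpu())
# Passing the kvstore to fit() is enough; the optimizer is registered with
# the server and the weight updates happen there.
mod.fit(train_iter,
        kvstore=kv,
        optimizer='sgd',
        optimizer_params={'learning_rate': 0.01},
        num_epoch=10)
```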

### Different Modes of Distributed Training
Distributed training itself is enabled when kvstore creation string contains the word `dist`.
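
For instance, the mode is selected purely by the creation string (shown here with `dist_async`; any of the types below can be substituted):

```python
import mxnet as mx

# The 'dist' prefix enables distributed training; the suffix selects the mode.
kv = mx.kv.create('dist_async')
print(kv.type, kv.rank, kv.num_workers)
```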

@@ -86,9 +103,9 @@ In this mode, if a worker crashes, then it halts the progress of all workers.
- `dist_async`: In asynchronous distributed training, the server receives gradients from one worker and immediately updates its store, which it uses to respond to any future pulls.
This means that a worker that finishes processing a batch can pull the current parameters from the server and start the next batch,
even if other workers haven't finished processing the earlier batch.
This is faster than `dist_sync` but can take more epochs to converge.
In `async` mode, it is required to pass an optimizer because in the absence of an optimizer kvstore would replace the stored weights with received weights and this doesn't make sense for training in asynchronous mode.
This is faster than `dist_sync` because there is no cost of synchronization, but can take more epochs to converge.
The update of weights is atomic, meaning no two updates happen on the same weight at the same time. However, the order of updates is not guaranteed.
In `async` mode, it is required to pass an optimizer because in the absence of an optimizer kvstore would replace the stored weights with received weights and this doesn't make sense for training in asynchronous mode. Hence, when using Gluon with `async` mode we need to set `update_on_kvstore` to `True`.

- `dist_sync_device`: Same as `dist_sync` except that when there are multiple GPUs being used on each node,
this mode aggregates gradients and updates weights on the GPU, while `dist_sync` does so in CPU memory.
8 changes: 7 additions & 1 deletion python/mxnet/gluon/trainer.py
@@ -187,6 +187,11 @@ def _init_kvstore(self):
        arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params}
        kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts),
                                                     arg_arrays)
        if kvstore and 'async' in kvstore.type and config['update_on_kvstore'] is not None\
            and not config['update_on_kvstore']:
            raise ValueError("Please set update_on_kvstore to true "
                             "when training in async mode.")

Contributor: If we are forcing the user to set this param, why don't we set it inside the function itself as a default value?

Member (Author): If the user does not set that variable explicitly (the default), then I set it to the right value. If the user explicitly sets it to False, then the error is raised.

        if config['update_on_kvstore'] is not None:
            update_on_kvstore = config['update_on_kvstore']
        if kvstore:
@@ -195,7 +200,8 @@ def _init_kvstore(self):
            self._distributed = 'dist' in kvstore.type
            if self._distributed:
                # kv.pull(row_sparse_grad) is not supported for dist kvstore
                update_on_kvstore = self._contains_sparse_weight or self._contains_sparse_grad
                update_on_kvstore = self._contains_sparse_weight or self._contains_sparse_grad \
                                    or 'async' in kvstore.type
            if update_on_kvstore:
                # optimizer preferably needs to be set before init for multiprecision
                kvstore.set_optimizer(self._optimizer)
48 changes: 48 additions & 0 deletions tests/nightly/dist_async_kvstore.py
@@ -0,0 +1,48 @@
#!/usr/bin/env python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: skip-file
import sys
sys.path.insert(0, "../../python/")
import mxnet as mx

kv = mx.kv.create('dist_async')
my_rank = kv.rank
nworker = kv.num_workers

def test_gluon_trainer_type():
    def check_trainer_kv_update(update_on_kv):
        params = mx.gluon.ParameterDict()
        x = params.get('x', shape=(10,1), lr_mult=1.0)
        params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
        try:
            trainer = mx.gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv, update_on_kvstore=update_on_kv)
            trainer._init_kvstore()
            assert trainer._kv_initialized
            assert trainer._update_on_kvstore is True
        except ValueError:
            assert update_on_kv is False

    check_trainer_kv_update(False)
    check_trainer_kv_update(True)
    check_trainer_kv_update(None)
    print('worker ' + str(my_rank) + ' passed test_gluon_trainer_type')

if __name__ == "__main__":
test_gluon_trainer_type()