Improve Optimizer docs, update quickstart to use Optimizer #416

Merged · 23 commits · Dec 2, 2021

Changes from 14 commits
14 changes: 5 additions & 9 deletions docs/modules/optim.rst
@@ -3,7 +3,10 @@

.. raw:: html

This module contains decentralized optimizers that wrap regular pytorch optimizers to collaboratively train a shared model. Depending on the exact type, optimizer may average model parameters with peers, exchange gradients, or follow a more complicated distributed training strategy.
This module contains decentralized optimizers that wrap your regular PyTorch Optimizer to train with peers.
Depending on the exact configuration, Optimizer may perform large synchronous updates equivalent to training
with a larger batch size, or perform asynchronous local updates and average model parameters in the background.

<br><br>

.. automodule:: hivemind.optim.experimental.optimizer
@@ -13,7 +16,7 @@
----------------------

.. autoclass:: Optimizer
:members: step, zero_grad, load_state_from_peers, param_groups, shutdown
:members: step, local_epoch, zero_grad, load_state_from_peers, param_groups, shutdown
:member-order: bysource

.. currentmodule:: hivemind.optim.grad_scaler
@@ -24,13 +27,6 @@
**CollaborativeOptimizer**
--------------------------

.. raw:: html

CollaborativeOptimizer is a legacy version of hivemind.Optimizer. **For new projects, please use hivemind.Optimizer.**
Currently, hivemind.Optimizer supports all the features of CollaborativeOptimizer and then some.
CollaborativeOptimizer will still be supported for awhile, but will eventually be deprecated.
<br><br>


.. automodule:: hivemind.optim.collaborative
.. currentmodule:: hivemind.optim
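As a rough illustration of the two regimes described in the module docstring above, hivemind.Optimizer switches between them through a single flag. A minimal sketch (parameter names follow the updated quickstart below; the model, run id, and batch sizes are arbitrary):

```python
import torch
import hivemind

model = torch.nn.Linear(10, 2)
dht = hivemind.DHT(start=True)

opt = hivemind.Optimizer(
    dht=dht,
    run_id='demo_run',
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
    # use_local_updates=False (the default): gradients are accumulated until peers
    # collectively reach target_batch_size, then one large synchronous update is applied.
    # use_local_updates=True: each opt.step() applies local gradients immediately,
    # while parameters are averaged with peers in the background.
    use_local_updates=True,
    batch_size_per_step=32,    # samples contributed by each call to opt.step
    target_batch_size=4096,    # collective samples per epoch (arbitrary value)
)
```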
58 changes: 30 additions & 28 deletions docs/user/quickstart.md
@@ -47,26 +47,27 @@ model = nn.Sequential(nn.Conv2d(3, 16, (5, 5)), nn.MaxPool2d(2, 2), nn.ReLU(),
nn.Flatten(), nn.Linear(32 * 5 * 5, 10))
opt = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


# Create DHT: a decentralized key-value storage shared between peers
dht = hivemind.DHT(start=True)
print("To join the training, use initial_peers =", [str(addr) for addr in dht.get_visible_maddrs()])

# Set up a decentralized optimizer that will average with peers in background
opt = hivemind.optim.DecentralizedOptimizer(
opt, # wrap the SGD optimizer defined above
dht, # use a DHT that is connected with other peers
average_parameters=True, # periodically average model weights in opt.step
average_gradients=False, # do not average accumulated gradients
prefix='my_cifar_run', # unique identifier of this collaborative run
target_group_size=16, # maximum concurrent peers for this run
opt = hivemind.Optimizer(
dht=dht, # use a DHT that is connected with other peers
run_id='my_cifar_run', # unique identifier of this collaborative run
optimizer=opt, # wrap the SGD optimizer defined above
use_local_updates=True, # perform optimizer steps with local gradients, average parameters in background
batch_size_per_step=32, # each call to opt.step adds this many samples towards the next epoch
target_batch_size=10000, # move to next epoch after peers collectively process this many samples
matchmaking_time=3.0, # when averaging parameters, gather peers in background for up to this many seconds
averaging_timeout=10.0, # give up on averaging if not successful in this many seconds
    verbose=True              # print logs incessantly
)

# Note: if you intend to use GPU, switch to it only after the decentralized optimizer is created
with tqdm() as progressbar:
while True:
for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=256):
for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=32):
opt.zero_grad()
loss = F.cross_entropy(model(x_batch), y_batch)
loss.backward()
@@ -78,15 +79,15 @@


As you can see, this code is regular PyTorch with one notable exception: it wraps your regular optimizer with a
`DecentralizedOptimizer`. This optimizer uses `DHT` to find other peers and tries to exchange weights them. When you run
`hivemind.Optimizer`. This optimizer uses `DHT` to find other peers and tries to exchange parameters with them. When you run
the code (please do so), you will see the following output:

```shell
To join the training, use initial_peers = ['/ip4/127.0.0.1/tcp/XXX/p2p/YYY']
[...] Starting a new averaging round with current parameters.
```

This is `DecentralizedOptimizer` telling you that it's looking for peers. Since there are no peers, we'll need to create
This is `hivemind.Optimizer` telling you that it's looking for peers. Since there are no peers, we'll need to create
them ourselves.

Copy the entire script (or notebook) and modify this line:
@@ -123,26 +124,28 @@
opt = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Create DHT: a decentralized key-value storage shared between peers
dht = hivemind.DHT(initial_peers=[COPY_FROM_ANOTHER_PEER_OUTPUTS], start=True)
dht = hivemind.DHT(initial_peers=[COPY_FROM_OTHER_PEERS_OUTPUTS], start=True)
print("To join the training, use initial_peers =", [str(addr) for addr in dht.get_visible_maddrs()])

# Set up a decentralized optimizer that will average with peers in background
opt = hivemind.optim.DecentralizedOptimizer(
opt, # wrap the SGD optimizer defined above
dht, # use a DHT that is connected with other peers
average_parameters=True, # periodically average model weights in opt.step
average_gradients=False, # do not average accumulated gradients
prefix='my_cifar_run', # unique identifier of this collaborative run
target_group_size=16, # maximum concurrent peers for this run
opt = hivemind.Optimizer(
dht=dht, # use a DHT that is connected with other peers
run_id='my_cifar_run', # unique identifier of this collaborative run
optimizer=opt, # wrap the SGD optimizer defined above
use_local_updates=True, # perform optimizer steps with local gradients, average parameters in background
batch_size_per_step=32, # each call to opt.step adds this many samples towards the next epoch
target_batch_size=10000, # move to next epoch after all peers collectively process this many samples
matchmaking_time=3.0, # when averaging parameters, gather peers in background for up to this many seconds
averaging_timeout=10.0, # give up on averaging if not successful in this many seconds
    verbose=True              # print logs incessantly
)

opt.averager.load_state_from_peers()
opt.load_state_from_peers()

# Note: if you intend to use GPU, switch to it only after the decentralized optimizer is created
with tqdm() as progressbar:
while True:
for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=256):
for x_batch, y_batch in torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=32):
opt.zero_grad()
loss = F.cross_entropy(model(x_batch), y_batch)
loss.backward()
@@ -166,17 +169,16 @@
during one of the calls to `opt.step()`. You can start more peers by replicating the same code as the second peer,
using either the first or second peer as `initial_peers`.
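
For example, a third peer could pass the multiaddrs of both existing peers when creating its DHT. A small sketch (the addresses below are placeholders; copy the real ones printed by your running peers):

```python
import hivemind

# Placeholder multiaddrs -- replace with the values printed by your peers
initial_peers = [
    "/ip4/127.0.0.1/tcp/1234/p2p/FIRST_PEER_ID",
    "/ip4/127.0.0.1/tcp/5678/p2p/SECOND_PEER_ID",
]
dht = hivemind.DHT(initial_peers=initial_peers, start=True)
```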

The only issue with this code is that each new peer starts with a different untrained network blends its un-trained
parameters with other peers, reseting their progress. You can see this effect as a spike increase in training loss
immediately after new peer joins training. To avoid this problem, the second peer can download the
current model/optimizer state from an existing peer right before it begins training on minibatches:
Each new peer starts with an untrained network and must download the latest training state before it can contribute.
By default, a new peer will automatically detect that it is out of sync and start ``Downloading parameters from peer <...>``.
To avoid wasting the first optimizer step, one can manually download the latest model/optimizer state right before training on minibatches begins:
```python
opt.averager.load_state_from_peers()
opt.load_state_from_peers()
```

Congrats, you've just started a pocket-sized experiment with decentralized deep learning!

However, this is just the bare minimum of what hivemind can do. In [this example](https://github.com/learning-at-home/hivemind/tree/master/examples/albert),
However, this is only the basics of what hivemind can do. In [this example](https://github.com/learning-at-home/hivemind/tree/master/examples/albert),
we show how to use a more advanced version of DecentralizedOptimizer to collaboratively train a large Transformer over the internet.
@borzunov (Member) commented on Dec 2, 2021:
Please replace DecentralizedOptimizer -> hivemind.Optimizer here and in L186.

Author reply: fixed

If you want to learn more about each individual component,
@borzunov (Member) commented on Dec 2, 2021:
Please replace:

  • (Li et al. 2020) -> Li et al. (2020)
  • (Ryabinin et al. 2021) -> Ryabinin et al. (2021)

3 changes: 3 additions & 0 deletions hivemind/compression/base.py
@@ -65,6 +65,9 @@ def estimate_compression_ratio(self, info: CompressionInfo) -> float:
"""Estimate the compression ratio without doing the actual compression; lower ratio = better compression"""
...

def __repr__(self):
return f"hivemind.{self.__class__.__name__}()"


class NoCompression(CompressionBase):
"""A dummy compression strategy that preserves the original tensor as is."""
4 changes: 4 additions & 0 deletions hivemind/optim/collaborative.py
@@ -57,6 +57,10 @@ class TrainingProgressSchema(BaseModel):

class CollaborativeOptimizer(DecentralizedOptimizerBase):
"""
:note: **For new projects please use hivemind.Optimizer**. CollaborativeOptimizer is an older version of that.
Currently, hivemind.Optimizer supports all the features of CollaborativeOptimizer and then some.
CollaborativeOptimizer will still be supported for a while, but eventually it will be deprecated.

An optimizer that performs model updates after collaboratively accumulating a target (large) batch size across peers

    These optimizers use DHT to track how much progress the collaboration has made towards the target batch size.
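
In line with the note added above, new projects would construct hivemind.Optimizer directly. A minimal sketch that mirrors CollaborativeOptimizer's behavior of accumulating a large collective batch (arguments follow the updated quickstart; the model, run id, and values are arbitrary):

```python
import torch
import hivemind

model = torch.nn.Linear(10, 2)
dht = hivemind.DHT(start=True)

# Without use_local_updates, gradients are accumulated and averaged until the peers
# collectively reach target_batch_size, then a single large update is applied --
# the same overall behavior CollaborativeOptimizer provides.
opt = hivemind.Optimizer(
    dht=dht,
    run_id='my_run',
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
    batch_size_per_step=32,
    target_batch_size=10000,
    matchmaking_time=3.0,
    averaging_timeout=10.0,
)
```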