
Add CollaborativeOptimizer #215

Merged: 37 commits merged into master from collaborative_averager on Apr 10, 2021
Conversation

leshanbog (Collaborator)

This PR introduces the objects necessary for collaborative training, so that client code for participation stays clean and concise (close to a default training loop).

  • WeightedAverager - a DecentralizedAverager that averages trainable parameters or gradients with peer-wise weights (see the sketch after this list)
  • CollaborativeOptimizer - performs a model update after collaboratively accumulating a target (large) batch size across peers
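
For intuition, here is a minimal sketch of the peer-wise weighting that WeightedAverager performs (illustrative only; the weighted_average helper is not the actual hivemind API, and using local batch sizes as weights is an assumption):

import torch

def weighted_average(tensors_per_peer, weights):
    # sum(weight_i * tensor_i) / sum(weight_i), applied per-tensor across peers
    total_weight = sum(weights)
    return [
        sum(w * t for w, t in zip(weights, group)) / total_weight
        for group in zip(*tensors_per_peer)
    ]

# two peers, one parameter tensor each; weights proportional to their local batch sizes
peer_a, peer_b = [torch.full((2,), 1.0)], [torch.full((2,), 3.0)]
print(weighted_average([peer_a, peer_b], weights=[1.0, 3.0]))  # [tensor([2.5000, 2.5000])]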

@justheuristic justheuristic requested review from nevec and foksly April 8, 2021 09:45
@mryab mryab self-requested a review April 8, 2021 11:56
@justheuristic (Member) commented Apr 8, 2021

Example usage:

import time, socket
import torch, torch.nn as nn
import hivemind

# start a coordinator DHT node on port 1337 unless one is already running on this machine
with socket.socket() as sock:
    coordinator_exists = sock.connect_ex(("127.0.0.1", 1337)) == 0
if not coordinator_exists:
    dht = hivemind.DHT(listen_on='127.0.0.1:1337', start=True)

model = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 32))

# each peer connects to the coordinator's DHT and joins the collaboration under the same prefix
opt = hivemind.CollaborativeOptimizer(
    opt=torch.optim.Adam(model.parameters()),
    dht=hivemind.DHT(initial_peers=['127.0.0.1:1337'], start=True),
    prefix='test_exp', target_group_size=2,
    target_batch_size=32, batch_size_per_step=1, verbose=True,
    start=True
)


# toy training loop: CollaborativeOptimizer counts the samples contributed at each step and
# performs an actual model update once the target batch size is accumulated across peers
while True:
    x = torch.randn(10, 32)
    time.sleep(1)
    loss = torch.mean((x - model(x)) ** 2)
    loss.backward()
    opt.step()

@mryab changed the title from "Collaborative averager" to "Add CollaborativeOptimizer" on Apr 9, 2021
@mryab (Member) left a comment:

The rationale for all the logger.debug replacements is that we don't want to pollute the output of training scripts unless the user asks for it. An alternative solution is to control the verbosity via an argument.
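
For illustration, the verbosity-argument alternative could look roughly like this (a hypothetical helper, not code from this PR):

import logging

logger = logging.getLogger(__name__)

def log_progress(message: str, verbose: bool):
    # with verbose=True the message is emitted at INFO level and reaches the user;
    # otherwise it stays at DEBUG and does not pollute the training script's output
    logger.log(logging.INFO if verbose else logging.DEBUG, message)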

min_refresh_period, max_refresh_period, default_refresh_period
self.expected_drift_peers, self.expected_drift_rate = expected_drift_peers, expected_drift_rate
self.averaging_timeout, self.metadata_expiration = averaging_timeout, metadata_expiration
self.averager = TrainingAverager(
Collaborator:

Maybe we should accept an averager class in the constructor, or provide a way to override the averager in a subclass?
I don't see an easy way to change the averaging logic (e.g. to also average optimizer statistics) without copy-pasting the whole class. Correct me if I'm wrong.

Member:

Good point, that might prove useful in subsequent research on secure/robust averagers.
Added an averager_cls parameter that allows the user to override the averager class.

Member:

After a conversation with @nevec, we have instead opted for a _make_averager method for more versatility.
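
A minimal sketch of how such a _make_averager hook could be used in a subclass (the signature, attribute names, and CustomTrainingAverager below are illustrative assumptions, not the actual API):

import hivemind

class MyCollaborativeOptimizer(hivemind.CollaborativeOptimizer):
    def _make_averager(self, **kwargs):
        # override the factory method to plug in a custom averager
        # (e.g. one that also averages optimizer statistics) without
        # copy-pasting the rest of CollaborativeOptimizer
        return CustomTrainingAverager(self.opt, dht=self.dht, **kwargs)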

@justheuristic merged commit 7bb6565 into master on Apr 10, 2021
@justheuristic deleted the collaborative_averager branch on April 10, 2021 19:30