Implement core functionality of hivemind.Optimizer #403

Merged: 146 commits on Nov 28, 2021
Changes from all commits

Commits
f34d02d
[WIP] main hivemind.Optimizer
justheuristic Nov 18, 2021
6480775
Merge branch 'master' into hivemind_optimizer_thirdtimesthecharm
justheuristic Nov 18, 2021
f69b32a
Merge remote-tracking branch 'origin/master' into hivemind_optimizer_…
justheuristic Nov 18, 2021
500715b
test
justheuristic Nov 18, 2021
033d62e
black-isort
justheuristic Nov 18, 2021
7bd28fd
add vision
justheuristic Nov 18, 2021
a9ace63
add vision
justheuristic Nov 18, 2021
bb39fbd
review
mryab Nov 18, 2021
78c59bd
review
justheuristic Nov 21, 2021
eb7e3f5
docstring
justheuristic Nov 21, 2021
9554720
rename to params
mryab Nov 21, 2021
9b68dad
black-isort
justheuristic Nov 21, 2021
59df40c
isort
justheuristic Nov 21, 2021
36abd5b
import-through
justheuristic Nov 21, 2021
b6a5757
param_groups
justheuristic Nov 21, 2021
5be0e09
remove debug logs
justheuristic Nov 21, 2021
7cf26e8
change verification order
justheuristic Nov 21, 2021
06c2cd0
prefetch=1
justheuristic Nov 21, 2021
749ab83
add closure support (suggested by @SeanNaren)
justheuristic Nov 22, 2021
0add448
add closure support (suggested by @SeanNaren)
justheuristic Nov 22, 2021
8632606
update GradSCaler
justheuristic Nov 22, 2021
53ddbf9
import-through
justheuristic Nov 22, 2021
32dd570
black
justheuristic Nov 22, 2021
b57f8e8
black
justheuristic Nov 22, 2021
6f4d164
black
justheuristic Nov 22, 2021
b7e1c94
isort
justheuristic Nov 22, 2021
e70e02f
isort
justheuristic Nov 22, 2021
c886472
undo scaler changes
justheuristic Nov 22, 2021
0f2cd94
undo scaler changes
justheuristic Nov 22, 2021
df9d80d
undo scaler changes
justheuristic Nov 22, 2021
7cd913d
actually unscale
justheuristic Nov 22, 2021
6acfaff
actually unscale
justheuristic Nov 22, 2021
844f63e
actually unscale
justheuristic Nov 22, 2021
dc76d68
actually unscale
justheuristic Nov 22, 2021
4932739
actually unscale
justheuristic Nov 22, 2021
6f063c6
actually unscale
justheuristic Nov 22, 2021
80ced17
actually unscale
justheuristic Nov 22, 2021
6b05812
try-finally
justheuristic Nov 22, 2021
5d5d688
try-finally
justheuristic Nov 22, 2021
59bcda1
rlock it
justheuristic Nov 22, 2021
94038bf
rlock it
justheuristic Nov 22, 2021
b35b6f6
rlock it
justheuristic Nov 22, 2021
0f2b356
update separately
justheuristic Nov 22, 2021
8c2febd
defer update to user
justheuristic Nov 22, 2021
08f5562
fix
justheuristic Nov 22, 2021
ac8fb55
load params into offloaded optimizer
justheuristic Nov 22, 2021
3fe3809
lock tensors
justheuristic Nov 22, 2021
b656f15
use local accumulators
justheuristic Nov 22, 2021
818a3a5
black-isort
justheuristic Nov 22, 2021
cd6a921
tests
justheuristic Nov 22, 2021
6d1a866
typo
justheuristic Nov 22, 2021
ac5e6cc
lock tensors
justheuristic Nov 22, 2021
da21ae3
lock tensors
justheuristic Nov 22, 2021
4e4343d
apply optimizer results faster
justheuristic Nov 22, 2021
cf57c29
apply to local parameters
justheuristic Nov 22, 2021
dab1cb2
less sleepy
justheuristic Nov 23, 2021
d241faf
Merge branch 'master' into hivemind_optimizer_thirdtimesthecharm
justheuristic Nov 23, 2021
6dc0045
hopefully fix freezing, add averaging frequency
justheuristic Nov 24, 2021
21a2a57
hopefully fix freezing, add averaging frequency
justheuristic Nov 24, 2021
5c32fd2
Merge remote-tracking branch 'origin/hivemind_optimizer_thirdtimesthe…
justheuristic Nov 24, 2021
db8607d
fix offloading lock
justheuristic Nov 24, 2021
3ff6c59
aux peers
justheuristic Nov 24, 2021
69d0a94
review
borzunov Nov 24, 2021
408e926
typo
justheuristic Nov 24, 2021
a45edcb
lengthy comment
justheuristic Nov 24, 2021
e46d53b
hopefully handle shutdown correctly
justheuristic Nov 24, 2021
e5d2b16
and now its black
justheuristic Nov 24, 2021
e924917
check for exact synchronization once per step
justheuristic Nov 25, 2021
a4f0f7b
check for exact synchronization once per step
justheuristic Nov 25, 2021
49c19f6
trick asyncio into submission
justheuristic Nov 25, 2021
7ed6f6d
trick asyncio into submission
justheuristic Nov 25, 2021
867af33
Merge remote-tracking branch 'origin/hivemind_optimizer_thirdtimesthe…
justheuristic Nov 25, 2021
dd7ed55
black-isort
justheuristic Nov 25, 2021
1dd5cd0
option to await a trigger
justheuristic Nov 25, 2021
cac9c5b
option to await a trigger
justheuristic Nov 25, 2021
581155a
option to await a trigger
justheuristic Nov 25, 2021
ec05f8b
[previous commit is verified as stable] implement delayed averaging
justheuristic Nov 25, 2021
b55451d
comment no longer valid
justheuristic Nov 25, 2021
242ef7f
pre-schedule state averaging
justheuristic Nov 25, 2021
42d8d63
black it
justheuristic Nov 25, 2021
a48a62a
switch to using new tracker interface
justheuristic Nov 25, 2021
c196b9b
black-isort
justheuristic Nov 25, 2021
49c1443
better defaults
justheuristic Nov 25, 2021
3766052
clarify end of step message
justheuristic Nov 25, 2021
6ead2dc
debug delay_grad_averaging into submission
justheuristic Nov 25, 2021
a3e900e
debug delay_grad_averaging into submission
justheuristic Nov 25, 2021
4545f31
Merge remote-tracking branch 'origin/master' into hivemind_optimizer_…
justheuristic Nov 25, 2021
d4b13e0
use averaging priority
justheuristic Nov 25, 2021
6990b16
black-isort
justheuristic Nov 25, 2021
4679e32
typo
justheuristic Nov 25, 2021
4aa38e7
PR attribution
foksly Nov 25, 2021
bf7a208
PR attribution (from CollaborativeOptimizer)
leshanbog Nov 25, 2021
da440f1
PR attribution (from aux peers)
yhn112 Nov 25, 2021
3a5ba37
black-isort
justheuristic Nov 25, 2021
8f7810e
reduce test duration
justheuristic Nov 25, 2021
561a0df
black-isort
justheuristic Nov 25, 2021
93b72c3
forked
justheuristic Nov 25, 2021
3c645ed
Update tests/test_optimizer.py
justheuristic Nov 25, 2021
9c2d4e4
Update hivemind/optim/experimental/progress_tracker.py
justheuristic Nov 25, 2021
dfec847
review
borzunov Nov 25, 2021
760e103
Merge remote-tracking branch 'origin/hivemind_optimizer_thirdtimesthe…
justheuristic Nov 25, 2021
6a9e2e8
review
justheuristic Nov 26, 2021
5a8fd7d
Update hivemind/optim/experimental/optimizer.py
justheuristic Nov 26, 2021
7c566e9
Update hivemind/optim/experimental/optimizer.py
justheuristic Nov 26, 2021
92f21dd
Update hivemind/optim/experimental/optimizer.py
justheuristic Nov 26, 2021
bd93881
Update benchmarks/benchmark_optimizer.py
justheuristic Nov 26, 2021
b9e8db1
review
borzunov Nov 26, 2021
dcc8e0a
Merge remote-tracking branch 'origin/hivemind_optimizer_thirdtimesthe…
justheuristic Nov 26, 2021
4921097
accelerate test
justheuristic Nov 26, 2021
2738f8c
prefix -> run_id
borzunov Nov 26, 2021
8f1186b
fluff
justheuristic Nov 26, 2021
754d3f9
un-flappy
justheuristic Nov 26, 2021
2a15c37
local updates and docstring
justheuristic Nov 27, 2021
1a6aa77
black-isort
justheuristic Nov 27, 2021
56bdc5c
ensure gpu compat
justheuristic Nov 27, 2021
3fd11f1
init averaged parameters
justheuristic Nov 27, 2021
3aa56b3
init averaged parameters
justheuristic Nov 27, 2021
024b339
force_copy
justheuristic Nov 27, 2021
f19150a
[fix from running mingpt]
justheuristic Nov 27, 2021
c60299d
fix for num_peers=1
justheuristic Nov 27, 2021
4b915ee
[final] adjust state scheduling
justheuristic Nov 27, 2021
4936f22
[final] adjust state scheduling
justheuristic Nov 27, 2021
6377ea7
better shutdown script
justheuristic Nov 27, 2021
cc17c69
cosmetics
justheuristic Nov 27, 2021
b43fde3
cosmetics
justheuristic Nov 27, 2021
744ba1a
warn on deleting future that requires trigger
justheuristic Nov 27, 2021
fe08462
prefix -> run_id
justheuristic Nov 27, 2021
f4e07fe
[benchmarking WIP] avoid deleting non-triggered StepControl
justheuristic Nov 27, 2021
9d41ebf
[benchmarking finished] avoid accidentally cancelling averaging rounds
justheuristic Nov 27, 2021
b2be06c
[benchmarking finished] avoid accidentally cancelling averaging rounds
justheuristic Nov 27, 2021
a31c669
Merge remote-tracking branch 'origin/hivemind_optimizer_thirdtimesthe…
justheuristic Nov 27, 2021
55e85e5
comment typo
justheuristic Nov 27, 2021
8b3a69d
[benchmarking #2] reduce waiting time
justheuristic Nov 27, 2021
af50253
remove min_matchmaking_time
justheuristic Nov 27, 2021
13904c3
reduce diff
justheuristic Nov 27, 2021
9bdf6c7
hopefully final update
justheuristic Nov 28, 2021
ffdec66
Avoid unnecessary lock on averaged tensors
justheuristic Nov 28, 2021
076b1f6
minor bugfix with all-reduce scheduling
justheuristic Nov 28, 2021
d655b74
handle cancellation before trigger
justheuristic Nov 28, 2021
8840aaa
notify peers if averaging round while awaiting trigger
justheuristic Nov 28, 2021
aa2b70e
unironiously black
justheuristic Nov 28, 2021
a22b6f2
use status_loglevel for loggign shutdown phases
justheuristic Nov 28, 2021
96e248f
[diff review] fix random typo
justheuristic Nov 28, 2021
6beee92
Update RTFD
justheuristic Nov 28, 2021
15d71a3
docs
justheuristic Nov 28, 2021
f0160ce
docs
justheuristic Nov 28, 2021
163 changes: 163 additions & 0 deletions benchmarks/benchmark_optimizer.py
@@ -0,0 +1,163 @@
import multiprocessing as mp
import random
import time
from contextlib import nullcontext
from dataclasses import dataclass
from functools import partial
from typing import Callable

import numpy as np
import torch
import torchvision
from torch import nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset

import hivemind
from hivemind.optim.experimental.optimizer import Optimizer
from hivemind.utils.crypto import RSAPrivateKey


@dataclass(frozen=True)
class TrainingArguments:
    seed: int = 42
    run_id: str = "my_exp"

    num_peers: int = 8
    num_clients: int = 3
    target_batch_size: int = 256
    reuse_grad_buffers: bool = True
    delay_grad_averaging: bool = True
    delay_optimizer_step: bool = True
    average_state_every: int = 1
    use_amp: bool = False

    lr_base: float = 0.1
    lr_gamma: int = 0.1
    lr_step_size: int = 10
    max_epoch: int = 25

    batch_size_min: int = 2
    batch_size_max: int = 16
    batch_time_min: float = 1.0
    batch_time_max: float = 4.5
    batch_time_std: float = 0.5

    matchmaking_time: float = 5.0
    max_refresh_period: float = 5.0
    averaging_timeout: float = 15.0
    winddown_time: float = 5.0
    verbose: bool = True

    device: str = "cpu"
    make_dataset: Callable[[], Dataset] = lambda: torchvision.datasets.MNIST(train=True, root=".", download=True)
    make_model: Callable[[int, int], nn.Module] = lambda num_features, num_classes: nn.Sequential(
        nn.Linear(num_features, 64), nn.ReLU(), nn.Linear(64, num_classes)
    )


def benchmark_optimizer(args: TrainingArguments):
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.set_num_threads(1)

    dht = hivemind.DHT(start=True)

    train_dataset = args.make_dataset()
    num_features = train_dataset.data[0].numel()
    num_classes = len(train_dataset.classes)
    X_train = torch.as_tensor(train_dataset.data, dtype=torch.float32)
    X_train = X_train.sub_(X_train.mean((0, 1, 2))).div_(X_train.std((0, 1, 2))).reshape((-1, num_features))
    y_train = torch.as_tensor(train_dataset.targets, dtype=torch.int64)
    del train_dataset

    def run_trainer(batch_size: int, batch_time: float, client_mode: bool, verbose: bool):
        model = args.make_model(num_features, num_classes).to(args.device)

        assert isinstance(model, torch.nn.Module), "model_arch must evaluate to a pytorch module"

        optimizer = Optimizer(
            run_id=args.run_id,
            target_batch_size=args.target_batch_size,
            batch_size_per_step=batch_size,
            params=model.parameters(),
            optimizer=partial(torch.optim.SGD, lr=args.lr_base),
            scheduler=partial(torch.optim.lr_scheduler.StepLR, gamma=args.lr_gamma, step_size=args.lr_step_size),
            dht=hivemind.DHT(initial_peers=dht.get_visible_maddrs(), client_mode=client_mode, start=True),
            tracker_opts=dict(private_key=RSAPrivateKey(), max_refresh_period=args.max_refresh_period),
            matchmaking_time=args.matchmaking_time,
            averaging_timeout=args.averaging_timeout,
            reuse_grad_buffers=args.reuse_grad_buffers,
            delay_grad_averaging=args.delay_grad_averaging,
            delay_optimizer_step=args.delay_optimizer_step,
            average_state_every=args.average_state_every,
            client_mode=client_mode,
            verbose=verbose,
        )

        if args.use_amp and args.reuse_grad_buffers:
            grad_scaler = hivemind.GradScaler()
        else:
            # check that hivemind.Optimizer supports regular PyTorch grad scaler as well
            grad_scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)

        prev_time = time.perf_counter()

        while optimizer.local_epoch < args.max_epoch:
            time.sleep(max(0.0, prev_time + random.gauss(batch_time, args.batch_time_std) - time.perf_counter()))

            batch = torch.randint(0, len(X_train), (batch_size,))

            with torch.cuda.amp.autocast() if args.use_amp else nullcontext():
                loss = F.cross_entropy(model(X_train[batch].to(args.device)), y_train[batch].to(args.device))
                grad_scaler.scale(loss).backward()

            grad_scaler.unscale_(optimizer)

            if args.use_amp:
                grad_scaler.step(optimizer)
            else:
                optimizer.step()

            grad_scaler.update()

            if not args.reuse_grad_buffers:
                optimizer.zero_grad()

            prev_time = time.perf_counter()

        time.sleep(args.winddown_time)
        optimizer.shutdown()

    peers = []

    for index in range(args.num_peers):
        batch_size = random.randint(args.batch_size_min, args.batch_size_max)
        batch_time = random.uniform(args.batch_time_min, args.batch_time_max)
        peers.append(
            mp.Process(
                target=run_trainer,
                name=f"trainer-{index}",
                daemon=False,
                kwargs=dict(
                    batch_size=batch_size,
                    batch_time=batch_time,
                    client_mode=(index >= args.num_peers - args.num_clients),
                    verbose=args.verbose and (index == 0),
                ),
            )
        )

    try:
        for peer in peers[1:]:
            peer.start()
        peers[0].run()
        for peer in peers[1:]:
            peer.join()
    finally:
        for peer in peers[1:]:
            peer.kill()


if __name__ == "__main__":
    benchmark_optimizer(TrainingArguments())
32 changes: 29 additions & 3 deletions docs/modules/optim.rst
@@ -1,14 +1,40 @@
**hivemind.optim**
==================

.. automodule:: hivemind.optim
.. currentmodule:: hivemind.optim

.. raw:: html

This module contains decentralized optimizers that wrap regular pytorch optimizers to collaboratively train a shared model. Depending on the exact type, optimizer may average model parameters with peers, exchange gradients, or follow a more complicated distributed training strategy.
<br><br>

.. automodule:: hivemind.optim.experimental.optimizer
.. currentmodule:: hivemind.optim.experimental.optimizer

**hivemind.Optimizer**
----------------------

.. autoclass:: Optimizer
:members: step, zero_grad, load_state_from_peers, param_groups, shutdown
:member-order: bysource

.. currentmodule:: hivemind.optim.grad_scaler
.. autoclass:: GradScaler
:member-order: bysource


**CollaborativeOptimizer**
--------------------------

.. raw:: html

CollaborativeOptimizer is a legacy version of hivemind.Optimizer. **For new projects, please use hivemind.Optimizer.**
Currently, hivemind.Optimizer supports all the features of CollaborativeOptimizer and then some.
CollaborativeOptimizer will still be supported for awhile, but will eventually be deprecated.
<br><br>


.. automodule:: hivemind.optim.collaborative
.. currentmodule:: hivemind.optim

.. autoclass:: CollaborativeOptimizer
:members: step
:member-order: bysource
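To make the documentation above concrete, here is a minimal usage sketch of the new hivemind.Optimizer, adapted from benchmarks/benchmark_optimizer.py earlier in this diff; the run_id, learning rate, and batch sizes are illustrative placeholders rather than recommended settings.

import torch
import torch.nn.functional as F

import hivemind

model = torch.nn.Linear(784, 10)
dht = hivemind.DHT(start=True)  # a real run would pass initial_peers=... to join existing peers

opt = hivemind.Optimizer(
    dht=dht,
    run_id="my_exp",                 # peers with the same run_id train together
    params=model.parameters(),
    optimizer=lambda params: torch.optim.SGD(params, lr=0.1),
    target_batch_size=256,           # global batch size that triggers one collaborative update
    batch_size_per_step=32,          # samples processed locally per step() call
    verbose=True,
)

x, y = torch.randn(32, 784), torch.randint(0, 10, (32,))  # stand-in for a real dataloader batch
loss = F.cross_entropy(model(x), y)
loss.backward()
opt.step()       # accumulates local progress; averages with peers once the target batch size is reached
opt.zero_grad()

opt.shutdown()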
2 changes: 2 additions & 0 deletions hivemind/__init__.py
@@ -16,6 +16,8 @@
DecentralizedOptimizer,
DecentralizedOptimizerBase,
DecentralizedSGD,
GradScaler,
Optimizer,
TrainingAverager,
)
from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID, PeerInfo
33 changes: 28 additions & 5 deletions hivemind/averaging/averager.py
@@ -35,6 +35,7 @@
from hivemind.utils import MPFuture, TensorDescriptor, get_logger
from hivemind.utils.asyncio import (
achain,
afirst,
aiter_with_timeout,
anext,
as_aiter,
@@ -413,11 +414,28 @@ async def _step(self, *, step: StepControl, future_for_init: MPFuture):
step.attach(trigger, cancel)
future_for_init.set_result((trigger, cancel))

async def find_peers_or_notify_cancel():
group_info = await self._matchmaking.look_for_group(step)
try:
if not step.triggered:
step.stage = AveragingStage.AWAITING_TRIGGER
await step.wait_for_trigger()
return group_info
except asyncio.CancelledError:
await asyncio.wait(
{
self._send_error_to_peer(peer_id, group_info.group_id, averaging_pb2.CANCELLED)
for peer_id in group_info.peer_ids
if peer_id != self.peer_id
}
)
raise

while not step.done():
try:
self._pending_group_assembled.clear()
step.stage = AveragingStage.LOOKING_FOR_GROUP
matchmaking_task = asyncio.create_task(self._matchmaking.look_for_group(step))
matchmaking_task = asyncio.create_task(find_peers_or_notify_cancel())
check_cancel_task = asyncio.create_task(step.wait_for_cancel())

await asyncio.wait({matchmaking_task, check_cancel_task}, return_when=asyncio.FIRST_COMPLETED)
@@ -428,13 +446,10 @@ async def _step(self, *, step: StepControl, future_for_init: MPFuture):
check_cancel_task.cancel()

group_info = await matchmaking_task

if group_info is None:
raise AllreduceException("Averaging step failed: could not find a group.")

if not step.triggered:
step.stage = AveragingStage.AWAITING_TRIGGER
await step.wait_for_trigger()

step.stage = AveragingStage.RUNNING_ALLREDUCE

step.set_result(
@@ -478,6 +493,14 @@ async def _step(self, *, step: StepControl, future_for_init: MPFuture):
)
)

async def _send_error_to_peer(self, peer_id: PeerID, group_id: GroupID, code: averaging_pb2.MessageCode):
try:
error = averaging_pb2.AveragingData(group_id=group_id, code=code)
stub = type(self).get_stub(self._p2p, peer_id, namespace=self.prefix)
await afirst(await stub.rpc_aggregate_part(as_aiter(error)))
except Exception as e:
logger.debug(f"Caught {e} when sending error {averaging_pb2.MessageCode.Name(code)} to {peer_id}.")

async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> GatheredData:
"""Run All-Reduce in a given group and update tensors in place, return gathered metadata"""
try:
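The averager change above moves the wait-for-trigger into the matchmaking task: once a group is assembled, the averager waits for the caller's explicit trigger, and if the step is cancelled while waiting it notifies the assembled peers instead of letting them time out. A self-contained sketch of that "await a trigger, notify peers on cancellation" pattern in plain asyncio (not hivemind's API):

import asyncio

async def await_trigger_or_notify(trigger: asyncio.Event, peers: list, notify_cancel):
    # Block until the caller explicitly allows the operation to proceed.
    # If this task is cancelled while waiting, tell the other peers in the
    # group so they can abort immediately rather than wait for a timeout.
    try:
        await trigger.wait()
    except asyncio.CancelledError:
        await asyncio.gather(*(notify_cancel(peer) for peer in peers), return_exceptions=True)
        raise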
9 changes: 9 additions & 0 deletions hivemind/averaging/control.py
@@ -1,3 +1,4 @@
import os
import struct
from enum import Enum
from typing import Optional
@@ -144,6 +145,14 @@ def __setstate__(self, state):
self._trigger, self._cancel, self._shared_buffer = state["_trigger"], state["_cancel"], state["_shared_buffer"]
self._data_for_gather, self._deadline, self._allow_retries = state["immutable_params"]

def __del__(self):
if os.getpid() == self._origin_pid and not self.triggered:
logger.warning(
"Deleted an averaging StepControl, but the step was not triggered. This may cause other "
"peers to fail an averaging round via TimeoutError."
)
super().__del__()

def cancel(self) -> bool:
if self._trigger is not None:
self._trigger.cancel()
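The __del__ guard above warns when a StepControl is garbage-collected before it was ever triggered, since silently dropping it would leave other peers waiting until they time out. A stripped-down, self-contained illustration of this "warn on un-triggered deletion" pattern (toy code, not hivemind's StepControl):

import logging
import os

logger = logging.getLogger(__name__)


class TriggeredHandle:
    """Toy stand-in for StepControl: it should be triggered before being dropped."""

    def __init__(self):
        self._origin_pid = os.getpid()
        self.triggered = False

    def allow(self):
        self.triggered = True

    def __del__(self):
        # Warn only in the process that created the handle and only if it was never triggered.
        if os.getpid() == self._origin_pid and not self.triggered:
            logger.warning("Handle deleted without being triggered; remote peers may time out.")


handle = TriggeredHandle()
handle.allow()  # forgetting this call would produce the warning once `handle` is garbage-collected
del handle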
8 changes: 5 additions & 3 deletions hivemind/averaging/matchmaking.py
@@ -88,9 +88,11 @@
async def looking_for_group(self, step_control: StepControl):
async with self.lock_looking_for_group:
assert self.step_control is None
self.step_control = step_control
yield
self.step_control = None
try:
self.step_control = step_control
yield
finally:
self.step_control = None
Review comment (member author): this ensures that step_control will be None if matchmaking is cancelled.


@property
def is_looking_for_group(self):
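The review comment above explains the try/finally in the matchmaking change: if the task running inside this async context manager is cancelled, the finally block still resets step_control to None. A self-contained illustration of the same pattern in plain asyncio (not hivemind code):

import asyncio
from contextlib import asynccontextmanager

state = {"step_control": None}


@asynccontextmanager
async def looking_for_group(step_control):
    state["step_control"] = step_control
    try:
        yield
    finally:
        # Runs even when the surrounding task is cancelled, so the slot is always cleared.
        state["step_control"] = None


async def matchmaking():
    async with looking_for_group(step_control="dummy"):
        await asyncio.sleep(10)  # cancelled long before this completes


async def main():
    task = asyncio.create_task(matchmaking())
    await asyncio.sleep(0.1)
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    assert state["step_control"] is None  # cleanup ran despite the cancellation


asyncio.run(main())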
2 changes: 1 addition & 1 deletion hivemind/averaging/partition.py
@@ -35,7 +35,7 @@ def __init__(
compression: CompressionBase = NoCompression(),
part_size_bytes: int = DEFAULT_PART_SIZE_BYTES,
tensor_infos: Optional[Sequence[CompressionInfo]] = None,
prefetch: int = 5,
prefetch: int = 1,
Review comment (member author): this prevents peers from eating up all CPU cores.

):
if tensor_infos is None:
tensor_infos = tuple(CompressionInfo.from_tensor(x, key=i) for i, x in enumerate(tensors))
3 changes: 2 additions & 1 deletion hivemind/optim/__init__.py
@@ -1,6 +1,7 @@
from hivemind.optim.adaptive import CollaborativeAdaptiveOptimizer
from hivemind.optim.base import DecentralizedOptimizerBase
from hivemind.optim.collaborative import CollaborativeOptimizer
from hivemind.optim.grad_scaler import HivemindGradScaler
from hivemind.optim.experimental.optimizer import Optimizer
from hivemind.optim.grad_scaler import GradScaler, HivemindGradScaler
from hivemind.optim.simple import DecentralizedAdam, DecentralizedOptimizer, DecentralizedSGD
from hivemind.optim.training_averager import TrainingAverager
4 changes: 2 additions & 2 deletions hivemind/optim/collaborative.py
@@ -245,7 +245,7 @@ def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindG
self.averager.local_step = self.collaboration_state.optimizer_step
logger.log(self.status_loglevel, f"Catching up with collaboration step {self.local_step}.")

if grad_scaler is not None and not grad_scaler.are_grads_finite(self.opt):
if grad_scaler is not None and not grad_scaler.are_grads_finite(self):
logger.log(self.status_loglevel, "Encountered incorrect value in fp16 grads, resetting local gradients")
self.local_samples_accumulated = self.local_steps_accumulated = 0
self.reset_accumulated_grads_()
@@ -310,7 +310,7 @@ def step(self, batch_size: Optional[int] = None, grad_scaler: Optional[HivemindG

if grad_scaler is not None:
with grad_scaler.running_global_step():
assert grad_scaler.step(self.opt)
assert grad_scaler.step(self)
else:
self.opt.step()
