Commit

Fix shape validation in GradientAverager
mryab committed Jun 8, 2022
1 parent 4a3d8fb commit 8e0036c
Showing 5 changed files with 28 additions and 8 deletions.
8 changes: 4 additions & 4 deletions hivemind/optim/grad_averager.py
@@ -106,11 +106,11 @@ def __init__(
                 grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
             )
         else:
-            if all(
-                params_grad.size() == grad.size()
-                for param_grad, grad in zip(self._grads_from_parameters(), averaged_grad)
+            if any(
+                param_grad.size() != grad.size()
+                for param_grad, grad in zip(self._grads_from_parameters(), averaged_grads)
             ):
-                raise ValueError("Averaged gradients doesn't have same shape as gradients from parameters")
+                raise ValueError("Averaged gradients don't have same shape as gradients from parameters")
         super().__init__(averaged_tensors=averaged_grads, dht=dht, prefix=prefix, client_mode=client_mode, **kwargs)

     def _grads_from_parameters(self) -> Iterator[torch.Tensor]:
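For reference, the corrected check can be exercised on its own. A minimal sketch, assuming two hypothetical tensor lists, param_grads standing in for self._grads_from_parameters() and averaged_grads for the user-supplied buffers:

import torch

param_grads = [torch.zeros(5, 5), torch.zeros(3)]        # shapes of the model's gradients
averaged_grads = [torch.zeros(5, 5), torch.zeros(3, 1)]  # second buffer has a mismatched shape

# Raise as soon as any pair of shapes disagrees, mirroring the new any(... != ...) condition.
if any(param_grad.size() != grad.size() for param_grad, grad in zip(param_grads, averaged_grads)):
    raise ValueError("Averaged gradients don't have same shape as gradients from parameters")

Note that the old form both inverted the condition (all(... == ...) raised only when every shape matched) and referenced the misspelled names params_grad and averaged_grad, which is what this commit fixes.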
1 change: 0 additions & 1 deletion hivemind/optim/optimizer.py
@@ -13,7 +13,6 @@
 from hivemind.dht import DHT
 from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
 from hivemind.optim.state_averager import (
     LRSchedulerBase,
3 changes: 1 addition & 2 deletions hivemind/optim/power_sgd_averager.py
@@ -1,6 +1,5 @@
 import asyncio
 import contextlib
-import multiprocessing as mp
 from enum import Enum
 from typing import Any, Iterable, Optional, Sequence

@@ -9,7 +8,7 @@
 from hivemind.averaging.allreduce import AveragingMode
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
-from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
+from hivemind.averaging.matchmaking import MatchmakingException
 from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.dht import DHT
 from hivemind.optim.grad_averager import GradientAverager
2 changes: 1 addition & 1 deletion hivemind/optim/state_averager.py
@@ -14,7 +14,7 @@
 from hivemind.averaging.control import StepControl
 from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.utils import DHTExpiration, PerformanceEMA, get_dht_time, get_logger, nested_flatten, nested_pack
+from hivemind.utils import DHTExpiration, PerformanceEMA, get_logger, nested_flatten, nested_pack

 logger = get_logger(__name__)

22 changes: 22 additions & 0 deletions tests/test_optimizer.py
@@ -83,6 +83,28 @@ def test_grad_averager(grad_averager_factory: GradientAveragerFactory):
     assert not torch.allclose(model2.w.grad, ref_average)


+@pytest.mark.forked
+@pytest.mark.parametrize(
+    "grad_averager_factory",
+    [GradientAverager, partial(PowerSGDGradientAverager, averager_rank=1)],
+)
+def test_grad_averager_wrong_shape(grad_averager_factory: GradientAveragerFactory):
+    parameter_shape = (5, 5)
+    model = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
+    dht = hivemind.DHT(start=True)
+
+    with pytest.raises(ValueError):
+        grad_averager_factory(
+            model.parameters(),
+            dht=dht,
+            prefix="test_fail",
+            target_group_size=2,
+            reuse_grad_buffers=False,
+            start=True,
+            averaged_grads=[torch.zeros(parameter_shape + (1,))],
+        )
+
+
 @pytest.mark.forked
 @pytest.mark.parametrize(
     "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",
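The new test deliberately passes averaged_grads of shape parameter_shape + (1,), i.e. (5, 5, 1), which cannot match the (5, 5) parameter, so construction must raise ValueError. For contrast, here is a sketch of the same construction with a correctly shaped buffer, which should pass the new validation; the prefix "test_ok" and the explicit shutdown calls are assumptions for illustration, not part of this commit:

import torch
import torch.nn as nn

import hivemind
from hivemind.optim.grad_averager import GradientAverager

parameter_shape = (5, 5)
model = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
dht = hivemind.DHT(start=True)

averager = GradientAverager(
    model.parameters(),
    dht=dht,
    prefix="test_ok",  # hypothetical prefix, mirrors the failing test above
    target_group_size=2,
    reuse_grad_buffers=False,
    start=True,
    averaged_grads=[torch.zeros(parameter_shape)],  # same (5, 5) shape as model.w, so no ValueError
)

averager.shutdown()
dht.shutdown()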

0 comments on commit 8e0036c
