Fix shape validation in GradientAverager #481

Merged 1 commit on Jun 8, 2022
8 changes: 4 additions & 4 deletions hivemind/optim/grad_averager.py
@@ -106,11 +106,11 @@ def __init__(
                 grad.detach().cpu().clone().share_memory_() for grad in self._grads_from_parameters()
             )
         else:
-            if all(
-                params_grad.size() == grad.size()
-                for param_grad, grad in zip(self._grads_from_parameters(), averaged_grad)
+            if any(
+                param_grad.size() != grad.size()
+                for param_grad, grad in zip(self._grads_from_parameters(), averaged_grads)
             ):
-                raise ValueError("Averaged gradients doesn't have same shape as gradients from parameters")
+                raise ValueError("Averaged gradients don't have same shape as gradients from parameters")
         super().__init__(averaged_tensors=averaged_grads, dht=dht, prefix=prefix, client_mode=client_mode, **kwargs)
 
     def _grads_from_parameters(self) -> Iterator[torch.Tensor]:
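The fix inverts the check: the old code raised only when every shape matched, and its generator referenced params_grad and averaged_grad instead of param_grad and averaged_grads. A minimal standalone sketch of the corrected validation, with illustrative tensors standing in for real parameter gradients:

import torch

# Stand-ins for gradients from model parameters and externally supplied buffers.
param_grads = [torch.zeros(5, 5), torch.zeros(3)]
averaged_grads = [torch.zeros(5, 5), torch.zeros(3, 1)]  # second buffer has the wrong shape

# Corrected logic: fail as soon as any pair of shapes disagrees.
if any(pg.size() != ag.size() for pg, ag in zip(param_grads, averaged_grads)):
    raise ValueError("Averaged gradients don't have same shape as gradients from parameters")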
1 change: 0 additions & 1 deletion hivemind/optim/optimizer.py
@@ -13,7 +13,6 @@
 from hivemind.dht import DHT
 from hivemind.optim.grad_averager import GradientAverager, GradientAveragerFactory
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.optim.power_sgd_averager import PowerSGDGradientAverager
 from hivemind.optim.progress_tracker import LocalTrainingProgress, ProgressTracker
 from hivemind.optim.state_averager import (
     LRSchedulerBase,
3 changes: 1 addition & 2 deletions hivemind/optim/power_sgd_averager.py
@@ -1,6 +1,5 @@
 import asyncio
 import contextlib
-import multiprocessing as mp
 from enum import Enum
 from typing import Any, Iterable, Optional, Sequence
 
@@ -9,7 +8,7 @@
 from hivemind.averaging.allreduce import AveragingMode
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.load_balancing import load_balance_peers
-from hivemind.averaging.matchmaking import Matchmaking, MatchmakingException
+from hivemind.averaging.matchmaking import MatchmakingException
 from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.dht import DHT
 from hivemind.optim.grad_averager import GradientAverager
2 changes: 1 addition & 1 deletion hivemind/optim/state_averager.py
@@ -14,7 +14,7 @@
 from hivemind.averaging.control import StepControl
 from hivemind.compression import CompressionInfo, TensorRole
 from hivemind.optim.grad_scaler import GradScaler
-from hivemind.utils import DHTExpiration, PerformanceEMA, get_dht_time, get_logger, nested_flatten, nested_pack
+from hivemind.utils import DHTExpiration, PerformanceEMA, get_logger, nested_flatten, nested_pack
 
 logger = get_logger(__name__)
 
22 changes: 22 additions & 0 deletions tests/test_optimizer.py
@@ -83,6 +83,28 @@ def test_grad_averager(grad_averager_factory: GradientAveragerFactory):
     assert not torch.allclose(model2.w.grad, ref_average)
 
 
+@pytest.mark.forked
+@pytest.mark.parametrize(
+    "grad_averager_factory",
+    [GradientAverager, partial(PowerSGDGradientAverager, averager_rank=1)],
+)
+def test_grad_averager_wrong_shape(grad_averager_factory: GradientAveragerFactory):
+    parameter_shape = (5, 5)
+    model = nn.ParameterDict({"w": nn.Parameter(torch.zeros(parameter_shape))})
+    dht = hivemind.DHT(start=True)
+
+    with pytest.raises(ValueError):
+        grad_averager_factory(
+            model.parameters(),
+            dht=dht,
+            prefix="test_fail",
+            target_group_size=2,
+            reuse_grad_buffers=False,
+            start=True,
+            averaged_grads=[torch.zeros(parameter_shape + (1,))],
+        )
+
+
 @pytest.mark.forked
 @pytest.mark.parametrize(
     "offload_optimizer, reuse_tensors, sync_epoch_when_averaging",