Simplify ExpertBackend interface (#483)
- extract gradient clipping from ExpertBackend: this behavior can be achieved with a user-defined Optimizer
- remove stats from ExpertBackend: this behavior can be achieved with a user-defined Scheduler
- rename full_state -> state_dict (rationale: there is no "non-full" state in this context)
- rename ExpertBackend.expert -> ExpertBackend.module to avoid confusion

Co-authored-by: Max Ryabinin <[email protected]>
1 parent: 6c56a87 · commit: 5ea21a7
Showing 17 changed files with 201 additions and 171 deletions.
Changed file (package exports of `hivemind.moe.server`): `ExpertBackend` is replaced by the re-exported `ModuleBackend`.

```diff
@@ -1,4 +1,4 @@
 from hivemind.moe.server.dht_handler import declare_experts, get_experts
-from hivemind.moe.server.expert_backend import ExpertBackend
 from hivemind.moe.server.layers import register_expert_class
+from hivemind.moe.server.module_backend import ModuleBackend
 from hivemind.moe.server.server import Server, background_server
```
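A hedged migration sketch for downstream code (assumes a hivemind checkout at this commit; the attribute renames are taken from the commit message above, and no concrete backend is constructed here):

```python
# The package now re-exports ModuleBackend instead of ExpertBackend:
from hivemind.moe.server import ModuleBackend  # was: ExpertBackend

# Attribute renames from the commit message (illustrative; `backend` would be
# a ModuleBackend instance constructed elsewhere):
#   backend.expert     -> backend.module      (the wrapped torch.nn.Module)
#   backend.full_state -> backend.state_dict  (there is no "non-full" state)
```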
Changed file (new module, @@ -0,0 +1,58 @@): optimizer wrappers that reproduce the extracted gradient-clipping behavior as a user-defined Optimizer.

```python
import torch


class OptimizerWrapper(torch.optim.Optimizer):
    """A wrapper for pytorch.optim.Optimizer that forwards all methods to the wrapped optimizer"""

    def __init__(self, optim: torch.optim.Optimizer):
        # Skip Optimizer.__init__: it assigns to `defaults`, `state` and
        # `param_groups`, which are read-only properties here; all state
        # lives in the wrapped optimizer instead.
        object.__init__(self)
        self.optim = optim

    @property
    def defaults(self):
        return self.optim.defaults

    @property
    def state(self):
        return self.optim.state

    def __getstate__(self):
        return self.optim.__getstate__()

    def __setstate__(self, state):
        self.optim.__setstate__(state)

    def __repr__(self):
        return f"{self.__class__.__name__}({repr(self.optim)})"

    def state_dict(self):
        return self.optim.state_dict()

    def load_state_dict(self, state_dict: dict) -> None:
        return self.optim.load_state_dict(state_dict)

    def step(self, *args, **kwargs):
        return self.optim.step(*args, **kwargs)

    def zero_grad(self, *args, **kwargs):
        return self.optim.zero_grad(*args, **kwargs)

    @property
    def param_groups(self):
        return self.optim.param_groups

    def add_param_group(self, param_group: dict) -> None:
        return self.optim.add_param_group(param_group)


class ClippingWrapper(OptimizerWrapper):
    """A wrapper of torch.Optimizer that clips gradients by global norm before each step"""

    def __init__(self, optim: torch.optim.Optimizer, clip_grad_norm: float):
        super().__init__(optim)
        self.clip_grad_norm = clip_grad_norm

    def step(self, *args, **kwargs):
        # clip all parameters across all param groups to one global norm, then step
        parameters = tuple(param for group in self.param_groups for param in group["params"])
        torch.nn.utils.clip_grad_norm_(parameters, self.clip_grad_norm)
        return super().step(*args, **kwargs)
```
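A minimal usage sketch (the toy model and hyperparameters are made up for illustration, not part of the commit), showing how the wrapper replaces the clipping that ExpertBackend used to perform internally:

```python
import torch
from torch import nn

model = nn.Linear(8, 4)  # hypothetical toy module
base_opt = torch.optim.SGD(model.parameters(), lr=0.1)

# clip gradients to a global norm of at most 1.0 before every optimizer step
opt = ClippingWrapper(base_opt, clip_grad_norm=1.0)

loss = model(torch.randn(2, 8)).sum()
loss.backward()
opt.step()       # clips, then delegates to SGD.step()
opt.zero_grad()  # forwarded to the wrapped SGD
```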