Support auxiliary participants in AllReduceProtocol #260
Conversation
Please make sure that AllreduceProtocol correctly handles all combinations of CLIENT/NODE/AUX among 3-4 peers (all permutations).
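One way to enumerate that test matrix is sketched below; the `MODES` tuple and the helper are illustrative stand-ins, not hivemind's actual test utilities:

```python
from itertools import product

# Hypothetical stand-in for hivemind's AveragingMode values
MODES = ("CLIENT", "NODE", "AUX")

def mode_assignments(n_peers):
    """Yield every assignment of one mode to each of n_peers peers."""
    yield from product(MODES, repeat=n_peers)

# All mode combinations for groups of 3 and 4 peers: 3**3 + 3**4 = 108 cases
cases = [combo for n in (3, 4) for combo in mode_assignments(n)]
```

Each of these cases would then be fed to a parametrized test that runs the allreduce step with the given mode assignment.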
```python
for i, peer in enumerate(self.ordered_group_endpoints)
if self.peer_modes[peer] != AveragingMode.CLIENT))
else:
    print(f'{self.endpoint} - NOT SENDING STUFF {self.peer_modes}')
```
Remove debug print or convert to a logger call (IMO the first is preferable)
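If the call were kept as a logger call instead, it could look like the minimal sketch below; the endpoint and peer-mode values here are placeholders standing in for `self.endpoint` and `self.peer_modes`, not hivemind's real state:

```python
import logging

logger = logging.getLogger(__name__)

# Placeholder values standing in for self.endpoint / self.peer_modes
endpoint = "127.0.0.1:31337"
peer_modes = {"127.0.0.1:31338": "AUX"}

# Lazy %-style formatting: the message string is only built if DEBUG is enabled
logger.debug("%s - not sending tensors, peer modes: %s", endpoint, peer_modes)
```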
```diff
@@ -144,6 +166,7 @@ def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAver

     async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
         """ Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
+        assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "auxiliary peers are disallowed from sending tensors"
```
Suggested change:
```diff
-assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "auxiliary peers are disallowed from sending tensors"
+assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers are disallowed from sending tensors"
```
```diff
@@ -87,6 +108,7 @@ def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
     assert source not in self.averaged_tensor_parts, "already registered the average from this peer"
     assert averaged_part.shape == self.local_tensor_parts[source].shape, "averaged part shape mismatch"
     assert averaged_part.dtype == self.local_tensor_parts[source].dtype, "averaged part dtype mismatch"
+    assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "auxiliary peers do not have local tensors for sending"
```
Suggested change:
```diff
-assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "auxiliary peers do not have local tensors for sending"
+assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers do not have local tensors for sending"
```
```diff
@@ -293,19 +301,20 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
     async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
         """ Use a group description found by Matchmaking to form AllreduceRunner """
         try:
-            weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
+            weights, throughputs, modes_ix, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
```
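For context, `modes_ix` presumably holds integer indices that are mapped back to `AveragingMode` members after unpacking. A minimal sketch of that round-trip, assuming the enum layout below (the actual hivemind enum may differ):

```python
from enum import IntEnum

class AveragingMode(IntEnum):
    """Assumed layout, for illustration only."""
    NODE = 0
    CLIENT = 1
    AUX = 2

modes_ix = (0, 2, 1)  # integer indices as gathered from peers
modes = tuple(AveragingMode(ix) for ix in modes_ix)
```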
I think it's still worth changing; I'm not sure `_ix` is a common suffix for indices.
Investigation report: example debug prints (from #263)
Resolve #233