
Support auxiliary participants in AllReduceProtocol #260

Merged: 25 commits from aux-peers into master, May 21, 2021
Conversation

@foksly (Collaborator) commented May 11, 2021

Resolve #233

@justheuristic (Member) commented:

Please make sure that AllreduceProtocol correctly handles all combinations of CLIENT/NODE/AUX among 3-4 peers (all permutations)
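
A minimal sketch of such a permutation test (not the PR's actual test): it reuses the DecentralizedAverager arguments from the test shown later in this thread, the AveragingMode import path is assumed, and the validity filter (at least one listening peer and at least one tensor-contributing peer) is one reading of the requirement. Enumerating every assignment for 3-4 peers makes this slow, so it is only an illustration.

import itertools

import pytest
import torch

import hivemind
from hivemind.client.averaging.allreduce import AveragingMode  # import path assumed


ALL_MODES = (AveragingMode.NODE, AveragingMode.CLIENT, AveragingMode.AUX)


def valid_assignments(n_peers):
    """Yield every per-peer mode assignment that can form a working group:
    at least one non-CLIENT peer (someone must listen) and at least one
    non-AUX peer (someone must contribute tensors)."""
    for combo in itertools.product(ALL_MODES, repeat=n_peers):
        has_listener = any(mode != AveragingMode.CLIENT for mode in combo)
        has_contributor = any(mode != AveragingMode.AUX for mode in combo)
        if has_listener and has_contributor:
            yield combo


@pytest.mark.forked
@pytest.mark.parametrize("modes", [combo for n_peers in (3, 4) for combo in valid_assignments(n_peers)])
def test_allreduce_mode_permutations(modes):
    dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
    averagers = [
        hivemind.DecentralizedAverager(
            [torch.randn(16)], dht=dht, target_group_size=len(modes),
            averaging_expiration=15, prefix='mygroup',
            listen=(mode != AveragingMode.CLIENT), listen_on='127.0.0.1:*',
            auxiliary=(mode == AveragingMode.AUX), start=True)
        for mode in modes
    ]
    futures = [averager.step(wait=False) for averager in averagers]
    for future in futures:
        result = future.result()
        for averager in averagers:
            # every group member should appear in every peer's result, as in the test later in this thread
            assert averager.endpoint in result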

@foksly requested review from mryab and justheuristic May 12, 2021 18:55
@foksly assigned and then unassigned leshanbog May 12, 2021
@foksly requested a review from leshanbog May 12, 2021 18:55
@foksly self-assigned this May 12, 2021
for i, peer in enumerate(self.ordered_group_endpoints)
if self.peer_modes[peer] != AveragingMode.CLIENT))
else:
print(f'{self.endpoint} - NOT SENDING STUFF {self.peer_modes}')
Review comment (Member):
Remove debug print or convert to a logger call (IMO the first is preferable)
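
If the second option is chosen, a minimal sketch of the conversion; the get_logger import path is assumed, and the snippet is meant to live inside the same method as the print above, where self.endpoint and self.peer_modes are in scope.

from hivemind.utils import get_logger  # import path assumed

logger = get_logger(__name__)

# instead of: print(f'{self.endpoint} - NOT SENDING STUFF {self.peer_modes}')
logger.debug(f"{self.endpoint}: auxiliary peer, not sending local tensors (peer modes: {self.peer_modes})")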

@@ -144,6 +166,7 @@ def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAver

async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torch.Tensor) -> torch.Tensor:
""" Send a part of local tensors and metadata to a single peer, receive the average for that part of tensors """
assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "auxiliary peers are disallowed from sending tensors"
Review comment (Member):
Suggested change:
- assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "auxiliary peers are disallowed from sending tensors"
+ assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers are disallowed from sending tensors"
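
For context, the three modes these asserts distinguish; the CLIENT and AUX values match the enum reprs in the test failure below, while NODE = 0 and the base class are assumptions.

from enum import Enum

class AveragingMode(Enum):
    NODE = 0    # regular peer: listens for requests and contributes local tensors
    CLIENT = 1  # client-mode peer: contributes tensors but does not listen for requests
    AUX = 2     # auxiliary peer: helps aggregate but has no local tensors of its own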

@@ -87,6 +108,7 @@ def register_averaged_part(self, source: Endpoint, averaged_part: torch.Tensor):
assert source not in self.averaged_tensor_parts, "already registered the average from this peer"
assert averaged_part.shape == self.local_tensor_parts[source].shape, "averaged part shape mismatch"
assert averaged_part.dtype == self.local_tensor_parts[source].dtype, "averaged part dtype mismatch"
assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "auxiliary peers do not have local tensors for sending"
Review comment (Member):
Suggested change:
- assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "auxiliary peers do not have local tensors for sending"
+ assert self.peer_modes[self.endpoint] != AveragingMode.AUX, "Auxiliary peers do not have local tensors for sending"

@@ -293,19 +301,20 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, weight: float,
async def _make_allreduce_runner(self, group_info: GroupInfo, min_vector_size: int, **kwargs) -> AllReduceRunner:
""" Use a group description found by Matchmaking to form AllreduceRunner """
try:
-            weights, throughputs, modes, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
+            weights, throughputs, modes_ix, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
Review comment (Member):
I think it's still worth changing, not sure if _ix is a common suffix for indices
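
A sketch of the rename this hints at: spell out that the gathered values are serialized mode indices and convert them back to AveragingMode right away. The variable names are illustrative, not the PR's final choice.

# inside _make_allreduce_runner, in place of the line above:
weights, throughputs, mode_indices, user_gathered = zip(*map(self.serializer.loads, group_info.gathered))
modes = tuple(AveragingMode(index) for index in mode_indices)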

@foksly (Collaborator, Author) commented May 20, 2021

Investigation report:
Problem: the test with 2 aux and 2 client peers failed once in a few hundred attempts.
Cause: both aux peers entered matchmaking with exactly the same dht_time and refused to request each other because of the ">=" comparison (to be fixed in a subsequent commit; a minimal illustration follows below).
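
A self-contained illustration of that deadlock; the predicate below is illustrative, not hivemind's actual matchmaking code.

def should_request(my_expiration: float, their_expiration: float) -> bool:
    # buggy behaviour: a peer only requests peers whose expiration is strictly
    # earlier than its own, i.e. it refuses anyone whose expiration is >= its own
    return their_expiration < my_expiration

# both aux peers read exactly the same dht_time, as in the log below
peer_a_expiration = peer_b_expiration = 1621532347.506267
assert not should_request(peer_a_expiration, peer_b_expiration)
assert not should_request(peer_b_expiration, peer_a_expiration)
# neither side ever sends a join request, so no full group forms; the follow-up
# commit breaks the tie so that one of the two peers still requests the other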

Example debug prints (from #263):

=================================== FAILURES ===================================
___________________________ test_allreduce_once[2-2] ___________________________
n_clients = 2, n_aux = 2

    @pytest.mark.forked
    @pytest.mark.parametrize("n_clients", [0, 1, 2])
    @pytest.mark.parametrize("n_aux", [0, 1, 2])
    def test_allreduce_once(n_clients, n_aux):
>       _test_allreduce_once(n_clients, n_aux)

tests/test_averaging.py:86: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

n_clients = 2, n_aux = 2

    def _test_allreduce_once(n_clients, n_aux):
        dht = hivemind.DHT(start=True, endpoint=f'{hivemind.LOCALHOST}:*')
    
        n_peers = 4
        modes = [AveragingMode.CLIENT] * n_clients + [AveragingMode.AUX] * n_aux + [AveragingMode.NODE] * (n_peers - n_clients - n_aux)
        # random.shuffle(modes)
    
        tensors1 = [torch.randn(123), torch.zeros(3)]
        tensors2 = [torch.rand(123), torch.ones(3)]
        tensors3 = [-torch.rand(123), torch.arange(3).to(torch.float32)]
        tensors4 = [torch.randn(123) ** 3, torch.arange(3).to(torch.float32) / 2]
        peer_tensors = [tensors1, tensors2, tensors3, tensors4]
    
        reference = [sum(tensors[i] for tensors, mode in zip(peer_tensors, modes)
                     if mode != AveragingMode.AUX) / max(1, n_peers - n_aux) for i in range(len(tensors1))]
    
        averagers = [hivemind.DecentralizedAverager(tensors, dht=dht, target_group_size=4, averaging_expiration=15,
                                                    prefix='mygroup', listen=mode != AveragingMode.CLIENT, listen_on='127.0.0.1:*',
                                                    auxiliary=(mode == AveragingMode.AUX), start=True)
                     for tensors, mode in zip(peer_tensors, modes)]
        futures = []
        for averager in averagers:
            futures.append(averager.step(wait=False))
        for future in futures:
            result = future.result()
            for averager in averagers:
>               assert averager.endpoint in result, f"{modes}"
E               AssertionError: [<AveragingMode.CLIENT: 1>, <AveragingMode.CLIENT: 1>, <AveragingMode.AUX: 2>, <AveragingMode.AUX: 2>]
E               assert '127.0.0.1:41813' in {'127.0.0.1:44459': None, 'client::2dd4f14d-6b9a-43b2-b754-52b4c29b21fa': None, 'client::59909d38-6f42-4340-adfa-87d57d5f6c3c': None}
E                +  where '127.0.0.1:41813' = DecentralizedAverager(127.0.0.1:41813).endpoint

tests/test_averaging.py:70: AssertionError
----------------------------- Captured stdout call -----------------------------
QQclient::2dd4f14d-6b9a-43b2-b754-52b4c29b21fa - current queue = {}
QQclient::2dd4f14d-6b9a-43b2-b754-52b4c29b21fa - awaiting queue update
QQclient::59909d38-6f42-4340-adfa-87d57d5f6c3c - current queue = {}
QQclient::59909d38-6f42-4340-adfa-87d57d5f6c3c - awaiting queue update
QQ127.0.0.1:44459 - current queue = {}
QQ127.0.0.1:44459 - awaiting queue updateQQ127.0.0.1:41813 - current queue = {}

QQ127.0.0.1:41813 - awaiting queue update
QQclient::2dd4f14d-6b9a-43b2-b754-52b4c29b21fa - current queue = {}
QQclient::2dd4f14d-6b9a-43b2-b754-52b4c29b21fa - awaiting queue update
QQ127.0.0.1:44459 - current queue = {}
QQ127.0.0.1:44459 - awaiting queue update
QQclient::59909d38-6f42-4340-adfa-87d57d5f6c3c - current queue = {}
QQclient::59909d38-6f42-4340-adfa-87d57d5f6c3c - awaiting queue update
QQ127.0.0.1:41813 - current queue = {}
QQ127.0.0.1:41813 - awaiting queue update
QQ127.0.0.1:41813 - current queue = {}
QQ127.0.0.1:41813 - awaiting queue update
QQ127.0.0.1:44459 - current queue = {}
QQ127.0.0.1:44459 - awaiting queue update
QQclient::2dd4f14d-6b9a-43b2-b754-52b4c29b21fa - current queue = {'127.0.0.1:44459': ValueWithExpiration(value=1621532347.506267, expiration_time=1621532347.506267)}
QQclient::2dd4f14d-6b9a-43b2-b754-52b4c29b21fa - yielding 127.0.0.1:44459
client::2dd4f14d-6b9a-43b2-b754-52b4c29b21fa - REQUEST TO 127.0.0.1:44459
QQclient::59909d38-6f42-4340-adfa-87d57d5f6c3c - current queue = {'127.0.0.1:44459': ValueWithExpiration(value=1621532347.506267, expiration_time=1621532347.506267)}
QQclient::59909d38-6f42-4340-adfa-87d57d5f6c3c - yielding 127.0.0.1:44459
client::59909d38-6f42-4340-adfa-87d57d5f6c3c - REQUEST TO 127.0.0.1:44459
127.0.0.1:44459 - incoming request from client::2dd4f14d-6b9a-43b2-b754-52b4c29b21fa (time=1621532332.5730908)
127.0.0.1:44459 - accepted request from client::2dd4f14d-6b9a-43b2-b754-52b4c29b21fa
client::2dd4f14d-6b9a-43b2-b754-52b4c29b21fa - joining the group of 127.0.0.1:44459; waiting for peers
127.0.0.1:44459 - incoming request from client::59909d38-6f42-4340-adfa-87d57d5f6c3c (time=1621532332.5739517)
127.0.0.1:44459 - accepted request from client::59909d38-6f42-4340-adfa-87d57d5f6c3c
client::59909d38-6f42-4340-adfa-87d57d5f6c3c - joining the group of 127.0.0.1:44459; waiting for peers
QQ127.0.0.1:41813 - current queue = {'127.0.0.1:44459': ValueWithExpiration(value=1621532347.506267, expiration_time=1621532347.506267)}
QQ127.0.0.1:41813 - awaiting queue update
QQ127.0.0.1:44459 - current queue = {'127.0.0.1:41813': ValueWithExpiration(value=1621532347.506267, expiration_time=1621532347.506267)}
QQ127.0.0.1:44459 - awaiting queue update
QQ127.0.0.1:41813 - current queue = {'127.0.0.1:44459': ValueWithExpiration(value=1621532347.506267, expiration_time=1621532347.506267)}
QQ127.0.0.1:41813 - awaiting queue update
QQ127.0.0.1:44459 - current queue = {'127.0.0.1:41813': ValueWithExpiration(value=1621532347.506267, expiration_time=1621532347.506267)}
QQ127.0.0.1:44459 - awaiting queue update
QQ127.0.0.1:41813 - current queue = {'127.0.0.1:44459': ValueWithExpiration(value=1621532347.506267, expiration_time=1621532347.506267)}
(...repeated several more times...)

QQ127.0.0.1:41813 - awaiting queue update
QQ127.0.0.1:44459 - current queue = {'127.0.0.1:41813': ValueWithExpiration(value=1621532347.506267, expiration_time=1621532347.506267)}
QQ127.0.0.1:44459 - awaiting queue update
QQ127.0.0.1:41813 - current queue = {'127.0.0.1:44459': ValueWithExpiration(value=1621532347.506267, expiration_time=1621532347.506267)}
QQ127.0.0.1:41813 - awaiting queue update
QQ127.0.0.1:44459 - current queue = {'127.0.0.1:41813': ValueWithExpiration(value=1621532347.506267, expiration_time=1621532347.506267)}
QQ127.0.0.1:44459 - awaiting queue update
127.0.0.1:44459 - beginning allreduce because time is up, group: {'client::2dd4f14d-6b9a-43b2-b754-52b4c29b21fa': endpoint: "client::2dd4f14d-6b9a-43b2-b754-52b4c29b21fa"
schema_hash: "K\250\212\000@B\331\376g\321#k-\336\364*\020<\254~"
expiration: 1621532347.5698905
gather: "\224\313?\360\000\000\000\000\000\000\300\001\304\001\300"
client_mode: true
, 'client::59909d38-6f42-4340-adfa-87d57d5f6c3c': endpoint: "client::59909d38-6f42-4340-adfa-87d57d5f6c3c"
schema_hash: "K\250\212\000@B\331\376g\321#k-\336\364*\020<\254~"
expiration: 1621532347.571211
gather: "\224\313?\360\000\000\000\000\000\000\300\001\304\001\300"
client_mode: true
} (plus 127.0.0.1:44459)target size = 4 (time=1621532347.5079453)
client::59909d38-6f42-4340-adfa-87d57d5f6c3c - beginning alreduce
client::2dd4f14d-6b9a-43b2-b754-52b4c29b21fa - beginning alreduce
127.0.0.1:44459 - finished processing request from client::59909d38-6f42-4340-adfa-87d57d5f6c3c
127.0.0.1:44459 - finished processing request from client::2dd4f14d-6b9a-43b2-b754-52b4c29b21fa
QQ127.0.0.1:41813 - current queue = {}
QQ127.0.0.1:41813 - awaiting queue update
QQ127.0.0.1:41813 - current queue = {'127.0.0.1:44459': ValueWithExpiration(value=1621532362.5062153, expiration_time=1621532362.5062153)}
QQ127.0.0.1:41813 - yielding 127.0.0.1:44459
127.0.0.1:41813 - REQUEST TO 127.0.0.1:44459
127.0.0.1:44459 - incoming request from 127.0.0.1:41813 (time=1621532347.5505655)
127.0.0.1:44459 - rejected request from 127.0.0.1:41813
127.0.0.1:41813 - requested 127.0.0.1:44459 to be my leader, but got rejected with NOT_LOOKING_FOR_GROUP
127.0.0.1:44459 - finished processing request from 127.0.0.1:41813
QQ127.0.0.1:41813 - current queue = {}
QQ127.0.0.1:41813 - awaiting queue update
QQ127.0.0.1:41813 - current queue = {}
QQ127.0.0.1:41813 - awaiting queue update

@justheuristic merged commit e58f65d into master May 21, 2021
@justheuristic deleted the aux-peers branch May 21, 2021 01:04