From dab00be872e6cb9cefa1fdf8cbd0b9448eb8092e Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Mon, 6 Nov 2023 13:49:51 -0500
Subject: [PATCH 1/9] logs

---
 megablocks/layers/moe.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index f73f0aea..2e13c589 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -374,6 +374,8 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
 
         # Locally permute the tokens and perform the expert computation.
         # Block to make sure that the cross-device permutation is complete.
+        print(f'Recv counts: {recv_counts}')
+        print(f'Tokens per expert: {parallel_tokens_per_expert}')
         parallel_x_handle.wait()
         parallel_x = self.permute_and_compute(
             parallel_x,

From b6237d889f18066743f88bf413c85f0ad77ecd66 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Mon, 6 Nov 2023 14:37:42 -0500
Subject: [PATCH 2/9] logs

---
 megablocks/layers/moe.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index 2e13c589..6f2c58c2 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -299,6 +299,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
         # TODO(tgale): It might be faster to do this on the GPU and
         # then communicate the results back to the host.
         send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
+        print(f'Parallel tokens per expert: {parallel_tokens_per_expert}')
         recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1)
 
         # Convert the send/recv counts to lists.

From 38cb36d400eff887232e1bc41d0a6fc65c20f6cd Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Mon, 6 Nov 2023 15:02:25 -0500
Subject: [PATCH 3/9] logs

---
 megablocks/layers/moe.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index 6f2c58c2..c9714669 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -299,8 +299,8 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
         # TODO(tgale): It might be faster to do this on the GPU and
         # then communicate the results back to the host.
         send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1)
-        print(f'Parallel tokens per expert: {parallel_tokens_per_expert}')
-        recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1)
+        parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
+        recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1)
 
         # Convert the send/recv counts to lists.
         send_counts = send_counts.tolist()
@@ -357,6 +357,8 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
         # Calculate the bins boundaries from the token counts.
         parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
             dim=0, dtype=torch.int)
+        parallel_tokens_per_expert_cpu = parallel_tokens_per_expert_cpu.sum(
+            dim=0, dtype=torch.int)
         parallel_bins = ops.inclusive_cumsum(
             parallel_tokens_per_expert, 0)
         parallel_bins = (
@@ -376,7 +378,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
         # Locally permute the tokens and perform the expert computation.
         # Block to make sure that the cross-device permutation is complete.
         print(f'Recv counts: {recv_counts}')
-        print(f'Tokens per expert: {parallel_tokens_per_expert}')
+        print(f'Tokens per expert CPU: {parallel_tokens_per_expert_cpu}')
         parallel_x_handle.wait()
         parallel_x = self.permute_and_compute(
             parallel_x,

From f3995527be416629c0aaa3f53a8128f9a9f9b8d1 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Mon, 6 Nov 2023 15:18:24 -0500
Subject: [PATCH 4/9] logs

---
 megablocks/layers/moe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index c9714669..a29a48f3 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -377,7 +377,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
 
         # Locally permute the tokens and perform the expert computation.
         # Block to make sure that the cross-device permutation is complete.
-        print(f'Recv counts: {recv_counts}')
+        print(f'Tokens per expert: {parallel_tokens_per_expert}')
         print(f'Tokens per expert CPU: {parallel_tokens_per_expert_cpu}')
         parallel_x_handle.wait()
         parallel_x = self.permute_and_compute(

From 4391127898b40d31769f57d77a7fb794ecc8b537 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Mon, 6 Nov 2023 15:29:52 -0500
Subject: [PATCH 5/9] cast to cpu

---
 megablocks/layers/moe.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index a29a48f3..48812dbc 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -377,8 +377,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
 
         # Locally permute the tokens and perform the expert computation.
         # Block to make sure that the cross-device permutation is complete.
-        print(f'Tokens per expert: {parallel_tokens_per_expert}')
-        print(f'Tokens per expert CPU: {parallel_tokens_per_expert_cpu}')
+        parallel_tokens_per_expert = parallel_tokens_per_expert_cpu if isinstance(self.mlp, mlp.GroupedMLP) else parallel_tokens_per_expert
         parallel_x_handle.wait()
         parallel_x = self.permute_and_compute(
             parallel_x,

From 1d91b99cd1a01d26d44f31d5940a27f562be78af Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Mon, 6 Nov 2023 16:35:17 -0500
Subject: [PATCH 6/9] simplify

---
 megablocks/layers/moe.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index 48812dbc..f0bea9fd 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -357,8 +357,6 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
         # Calculate the bins boundaries from the token counts.
         parallel_tokens_per_expert = parallel_tokens_per_expert.sum(
             dim=0, dtype=torch.int)
-        parallel_tokens_per_expert_cpu = parallel_tokens_per_expert_cpu.sum(
-            dim=0, dtype=torch.int)
         parallel_bins = ops.inclusive_cumsum(
             parallel_tokens_per_expert, 0)
         parallel_bins = (
@@ -375,7 +375,9 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
 
         # Locally permute the tokens and perform the expert computation.
         # Block to make sure that the cross-device permutation is complete.
-        parallel_tokens_per_expert = parallel_tokens_per_expert_cpu if isinstance(self.mlp, mlp.GroupedMLP) else parallel_tokens_per_expert
+        if isinstance(self.mlp, mlp.GroupedMLP): # GroupedMLP requires counts on CPU
+            parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
+                dim=0, dtype=torch.int)
         parallel_x_handle.wait()
         parallel_x = self.permute_and_compute(
             parallel_x,

From 673f645bc20efc5381bbd50a157a0b93e3c52f8b Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Mon, 6 Nov 2023 16:47:59 -0500
Subject: [PATCH 7/9] move comment

---
 megablocks/layers/moe.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index f0bea9fd..56ec686d 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -375,9 +375,9 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
 
         # Locally permute the tokens and perform the expert computation.
         # Block to make sure that the cross-device permutation is complete.
-        if isinstance(self.mlp, mlp.GroupedMLP): # GroupedMLP requires counts on CPU
+        if isinstance(self.mlp, mlp.GroupedMLP):
             parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
-                dim=0, dtype=torch.int)
+                dim=0, dtype=torch.int) # GroupedMLP requires counts on CPU
         parallel_x_handle.wait()
         parallel_x = self.permute_and_compute(
             parallel_x,

From 1a766e1f4e93d6bcef89aca6211dc2a808bd60a4 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Mon, 6 Nov 2023 16:54:17 -0500
Subject: [PATCH 8/9] updated commit

---
 megablocks/layers/moe.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index 56ec686d..175ae84f 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -376,8 +376,11 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
         # Locally permute the tokens and perform the expert computation.
         # Block to make sure that the cross-device permutation is complete.
         if isinstance(self.mlp, mlp.GroupedMLP):
+            # GroupedMLP requires counts on CPU. We can use the tensor already
+            # moved to CPU in the prior all_to_all, which avoids an extra
+            # device synchronization.
             parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
-                dim=0, dtype=torch.int) # GroupedMLP requires counts on CPU
+                dim=0, dtype=torch.int)
         parallel_x_handle.wait()
         parallel_x = self.permute_and_compute(
             parallel_x,

From 2066a3e03bba198efd7f74dae01e28e108387973 Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Mon, 6 Nov 2023 17:06:06 -0500
Subject: [PATCH 9/9] fix comment

---
 megablocks/layers/moe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index 175ae84f..a74036a0 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -377,7 +377,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
         # Block to make sure that the cross-device permutation is complete.
         if isinstance(self.mlp, mlp.GroupedMLP):
             # GroupedMLP requires counts on CPU. We can use the tensor already
-            # moved to CPU in the prior all_to_all, which avoids an extra
+            # moved to CPU for the prior all_to_all, which avoids an extra
             # device synchronization.
             parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
                 dim=0, dtype=torch.int)
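
For reference, a minimal standalone sketch of the pattern this series converges on (the helper name and the `is_grouped_mlp` flag are illustrative stand-ins, not the megablocks API): the expert token counts are copied to the host once to build the all_to_all send/recv splits, and that same CPU tensor is reused for GroupedMLP's group sizes, so no second device-to-host transfer (and the GPU sync it would force) is needed.

import torch

def split_and_group_counts(repeated_tokens_per_expert: torch.Tensor,
                           parallel_tokens_per_expert: torch.Tensor,
                           is_grouped_mlp: bool):
    # Single device-to-host copy, used first for the all_to_all splits...
    send_counts = repeated_tokens_per_expert.cpu().sum(dim=-1).tolist()
    parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu()
    recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1).tolist()

    # ...and reused for GroupedMLP, which wants its per-expert group sizes
    # on CPU; other MLP variants keep the counts on the original device.
    tokens_per_expert = parallel_tokens_per_expert.sum(dim=0, dtype=torch.int)
    if is_grouped_mlp:
        tokens_per_expert = parallel_tokens_per_expert_cpu.sum(
            dim=0, dtype=torch.int)
    return send_counts, recv_counts, tokens_per_expert

For example, with two source ranks and two local experts, passing torch.tensor([[3, 1], [2, 4]]) as parallel_tokens_per_expert gives recv_counts [4, 6] and per-expert counts [5, 5].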