diff --git a/fbgemm_gpu/experimental/gen_ai/src/gather_scatter/gather_scatter.cpp b/fbgemm_gpu/experimental/gen_ai/src/gather_scatter/gather_scatter.cpp
index 633a0aca9..50b008492 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/gather_scatter/gather_scatter.cpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/gather_scatter/gather_scatter.cpp
@@ -20,6 +20,22 @@ void scatter_add_along_first_dim(
     at::Tensor src,
     at::Tensor index);
 
+at::Tensor gather_along_first_dim_meta(
+    const at::Tensor& data,
+    at::Tensor index) {
+  int K = data.size(1);
+  int N = index.size(0);
+  at::Tensor output = at::empty({N, K}, data.options());
+  return output;
+}
+
+void scatter_add_along_first_dim_meta(
+    const at::Tensor& dst,
+    const at::Tensor& src,
+    const at::Tensor& index) {
+  return;
+}
+
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.set_python_module("fbgemm_gpu.experimental.gen_ai.gather_scatter");
   m.def("gather_along_first_dim(Tensor Data, Tensor Index) -> Tensor");
@@ -32,6 +48,10 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("scatter_add_along_first_dim", scatter_add_along_first_dim);
 }
 
+TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
+  m.impl("gather_along_first_dim", gather_along_first_dim_meta);
+  m.impl("scatter_add_along_first_dim", scatter_add_along_first_dim_meta);
+}
 #endif
 
 } // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/test/gather_scatter/gather_scatter_test.py b/fbgemm_gpu/experimental/gen_ai/test/gather_scatter/gather_scatter_test.py
index 923c29ffc..23b36c099 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/gather_scatter/gather_scatter_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/gather_scatter/gather_scatter_test.py
@@ -27,7 +27,9 @@ class GatherScatterTests(unittest.TestCase):
     """Test Gathers."""
 
     def test_gather_along_first_dim(self) -> None:
-        def _test_gather_along_first_dim(M: int, N: int, K: int) -> None:
+        def _test_gather_along_first_dim(
+            M: int, N: int, K: int, compile: bool = False
+        ) -> None:
             logger.info(f"Running test_gather_along_first_dim: {M=}, {N=}, {K=}")
             src = torch.randn([M, K], device="cuda", dtype=torch.bfloat16).abs()
             if M == N:
@@ -36,7 +38,10 @@ def _test_gather_along_first_dim(M: int, N: int, K: int) -> None:
                 indices = torch.randint(0, M, [N], device="cuda", dtype=torch.int32)
 
             def fn():
-                return torch.ops.fbgemm.gather_along_first_dim(src, indices)
+                op = torch.ops.fbgemm.gather_along_first_dim
+                if compile:
+                    op = torch.compile(op, backend="inductor", fullgraph=True)
+                return op(src, indices)
 
             def ref_fn():
                 return torch.index_select(src, 0, indices)
@@ -71,38 +76,41 @@ def ref_fn():
         _test_gather_along_first_dim(255, 129, 2049)
         _test_gather_along_first_dim(255, 129, 2048)
         _test_gather_along_first_dim(1024, 1024, 1024)
+        _test_gather_along_first_dim(1024, 1024, 1024, compile=True)
 
     def test_scatter_add_along_first_dim(self) -> None:
-        def _test_scatter_add_along_first_dim(M: int, N: int, K: int) -> None:
+        def _test_scatter_add_along_first_dim(
+            M: int, N: int, K: int, compile: bool = False
+        ) -> None:
             logger.info(f"Running test_scatter_add_along_first_dim: {M=}, {N=}, {K=}")
             src = torch.randn([M, K], device="cuda", dtype=torch.bfloat16).abs()
             dst = torch.randn([N, K], device="cuda", dtype=torch.bfloat16).abs()
             if M == N:
-                indices = torch.randperm(N, device="cuda", dtype=torch.int32)
+                indices_1d = torch.randperm(N, device="cuda", dtype=torch.int64)
             else:
-                indices = torch.randint(0, N, [M], device="cuda", dtype=torch.int32)
+                indices_1d = torch.randint(0, N, [M], device="cuda", dtype=torch.int64)
 
-            indices_int32 = indices.to(torch.int32)
-            indices_int64 = indices.to(torch.int64).unsqueeze(1).expand(-1, K)
+            indices_2d = indices_1d.to(torch.int64).unsqueeze(1).expand(-1, K)
 
             test_dst = dst.clone()
             ref_dst = dst.clone()
 
             logger.info("Running FBGMM")
-            torch.ops.fbgemm.scatter_add_along_first_dim(test_dst, src, indices_int32)
+            torch.ops.fbgemm.scatter_add_along_first_dim(test_dst, src, indices_1d)
 
             logger.info("Running PyTorch")
-            ref_dst.scatter_add_(0, indices_int64, src)
+            ref_dst.scatter_add_(0, indices_2d, src)
 
             torch.testing.assert_close(test_dst, ref_dst, atol=1e-3, rtol=2e-2)
 
             def fn():
-                torch.ops.fbgemm.scatter_add_along_first_dim(
-                    test_dst, src, indices_int32
-                )
+                op = torch.ops.fbgemm.scatter_add_along_first_dim
+                if compile:
+                    op = torch.compile(op, backend="inductor", fullgraph=True)
+                op(test_dst, src, indices_1d)
 
             def ref_fn():
-                ref_dst.scatter_add_(0, indices_int64, src)
+                ref_dst.scatter_add_(0, indices_2d, src)
 
             # Load src, load dst, store dst. x3.
             data_size_in_terabytes = N * K * 2 * 3 / 1e12
@@ -127,6 +135,7 @@ def ref_fn():
         _test_scatter_add_along_first_dim(255, 129, 2049)
         _test_scatter_add_along_first_dim(255, 129, 2048)
         _test_scatter_add_along_first_dim(1024, 1024, 1024)
+        _test_scatter_add_along_first_dim(1024, 1024, 1024, compile=True)
 
 
 if __name__ == "__main__":