Adds shape information to enable torch.compile. (#3724)
Summary:
X-link: facebookresearch/FBGEMM#807


Adds shape information so that the custom ops can be traced by torch.compile.

Differential Revision: D69993984
levendlee authored and facebook-github-bot committed Feb 23, 2025
1 parent 221c2aa commit fbc8389
Showing 2 changed files with 40 additions and 13 deletions.
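
Background on the change: torch.compile traces programs with FakeTensors, which carry shapes and dtypes but no data. A custom op therefore needs a Meta-device ("meta"/"fake") kernel that computes output shapes without launching the real CUDA kernel; without one, tracing the op fails. Below is a minimal sketch of the idea, not taken from this commit (the toy op demo::gather_rows and its Python-side registration are illustrative only; this commit registers its meta kernels in C++):

import torch

# Define a toy custom op. The eager implementation gathers rows of `data`.
@torch.library.custom_op("demo::gather_rows", mutates_args=())
def gather_rows(data: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    return data[index.long()]

# The fake (meta) kernel: shape propagation only, no data access.
# Output is [N, K] where N = index.size(0) and K = data.size(1).
@gather_rows.register_fake
def _(data: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    return data.new_empty((index.size(0), data.size(1)))

# With the fake kernel in place, the op can be traced and compiled.
compiled = torch.compile(gather_rows, fullgraph=True)
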
The first changed file (the C++ op registration):

@@ -20,6 +20,20 @@ void scatter_add_along_first_dim(
     at::Tensor src,
     at::Tensor index);
 
+at::Tensor gather_along_first_dim_meta(at::Tensor data, at::Tensor index) {
+  int K = data.size(1);
+  int N = index.size(0);
+  at::Tensor output = at::empty({N, K}, data.options());
+  return output;
+}
+
+void scatter_add_along_first_dim_meta(
+    at::Tensor /*dst*/,
+    at::Tensor /*src*/,
+    at::Tensor /*index*/) {
+  return;
+}
+
 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.set_python_module("fbgemm_gpu.experimental.gen_ai.gather_scatter");
   m.def("gather_along_first_dim(Tensor Data, Tensor Index) -> Tensor");
@@ -32,6 +46,10 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("scatter_add_along_first_dim", scatter_add_along_first_dim);
 }
 
+TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
+  m.impl("gather_along_first_dim", gather_along_first_dim_meta);
+  m.impl("scatter_add_along_first_dim", scatter_add_along_first_dim_meta);
+}
 #endif
 
 } // namespace fbgemm_gpu
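
The TORCH_LIBRARY_IMPL(fbgemm, Meta, m) block registers the two *_meta functions under the Meta dispatch key, so calls on meta/fake tensors (which is how torch.compile traces) dispatch to the shape-only implementations rather than the CUDA kernels. Note that scatter_add_along_first_dim_meta is a no-op because the op mutates dst in place and returns nothing, so there is no output to allocate. A quick way to exercise the registration, assuming the fbgemm_gpu gen_ai extension is built and importable (the exact import path is an assumption here):

import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401  # assumed to load the fbgemm ops

# Meta tensors have shapes/dtypes but no storage; only the meta kernel can run.
data = torch.empty(255, 2048, device="meta", dtype=torch.bfloat16)
index = torch.empty(129, device="meta", dtype=torch.int32)

out = torch.ops.fbgemm.gather_along_first_dim(data, index)
assert out.shape == (129, 2048) and out.device.type == "meta"
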
The second changed file (the Python unit tests):

@@ -27,7 +27,9 @@ class GatherScatterTests(unittest.TestCase):
     """Test Gathers."""
 
     def test_gather_along_first_dim(self) -> None:
-        def _test_gather_along_first_dim(M: int, N: int, K: int) -> None:
+        def _test_gather_along_first_dim(
+            M: int, N: int, K: int, compile: bool = False
+        ) -> None:
             logger.info(f"Running test_gather_along_first_dim: {M=}, {N=}, {K=}")
             src = torch.randn([M, K], device="cuda", dtype=torch.bfloat16).abs()
             if M == N:
@@ -36,7 +38,10 @@ def _test_gather_along_first_dim(M: int, N: int, K: int) -> None:
                 indices = torch.randint(0, M, [N], device="cuda", dtype=torch.int32)
 
             def fn():
-                return torch.ops.fbgemm.gather_along_first_dim(src, indices)
+                op = torch.ops.fbgemm.gather_along_first_dim
+                if compile:
+                    op = torch.compile(op, backend="inductor", fullgraph=True)
+                return op(src, indices)
 
             def ref_fn():
                 return torch.index_select(src, 0, indices)
@@ -71,38 +76,41 @@ def ref_fn():
         _test_gather_along_first_dim(255, 129, 2049)
         _test_gather_along_first_dim(255, 129, 2048)
         _test_gather_along_first_dim(1024, 1024, 1024)
+        _test_gather_along_first_dim(1024, 1024, 1024, compile=True)
 
     def test_scatter_add_along_first_dim(self) -> None:
-        def _test_scatter_add_along_first_dim(M: int, N: int, K: int) -> None:
+        def _test_scatter_add_along_first_dim(
+            M: int, N: int, K: int, compile: bool = False
+        ) -> None:
             logger.info(f"Running test_scatter_add_along_first_dim: {M=}, {N=}, {K=}")
             src = torch.randn([M, K], device="cuda", dtype=torch.bfloat16).abs()
             dst = torch.randn([N, K], device="cuda", dtype=torch.bfloat16).abs()
             if M == N:
-                indices = torch.randperm(N, device="cuda", dtype=torch.int32)
+                indices_1d = torch.randperm(N, device="cuda", dtype=torch.int64)
             else:
-                indices = torch.randint(0, N, [M], device="cuda", dtype=torch.int32)
+                indices_1d = torch.randint(0, N, [M], device="cuda", dtype=torch.int64)
 
-            indices_int32 = indices.to(torch.int32)
-            indices_int64 = indices.to(torch.int64).unsqueeze(1).expand(-1, K)
+            indices_2d = indices_1d.to(torch.int64).unsqueeze(1).expand(-1, K)
 
             test_dst = dst.clone()
             ref_dst = dst.clone()
 
             logger.info("Running FBGMM")
-            torch.ops.fbgemm.scatter_add_along_first_dim(test_dst, src, indices_int32)
+            torch.ops.fbgemm.scatter_add_along_first_dim(test_dst, src, indices_1d)
 
             logger.info("Running PyTorch")
-            ref_dst.scatter_add_(0, indices_int64, src)
+            ref_dst.scatter_add_(0, indices_2d, src)
 
             torch.testing.assert_close(test_dst, ref_dst, atol=1e-3, rtol=2e-2)
 
             def fn():
-                torch.ops.fbgemm.scatter_add_along_first_dim(
-                    test_dst, src, indices_int32
-                )
+                op = torch.ops.fbgemm.scatter_add_along_first_dim
+                if compile:
+                    op = torch.compile(op, backend="inductor", fullgraph=True)
+                op(test_dst, src, indices_1d)
 
             def ref_fn():
-                ref_dst.scatter_add_(0, indices_int64, src)
+                ref_dst.scatter_add_(0, indices_2d, src)
 
             # Load src, load dst, store dst. x3.
             data_size_in_terabytes = N * K * 2 * 3 / 1e12
@@ -127,6 +135,7 @@ def ref_fn():
         _test_scatter_add_along_first_dim(255, 129, 2049)
         _test_scatter_add_along_first_dim(255, 129, 2048)
         _test_scatter_add_along_first_dim(1024, 1024, 1024)
+        _test_scatter_add_along_first_dim(1024, 1024, 1024, compile=True)
 
 
 if __name__ == "__main__":
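
The new compile=True paths compile the raw op object directly. An equivalent way to exercise the same behavior, sketched here as a standalone snippet (the shapes and the inductor backend mirror the tests; the wrapper function itself is not from the commit):

import torch

@torch.compile(backend="inductor", fullgraph=True)
def gather(src: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Traces through the custom op; requires the Meta kernel added above.
    return torch.ops.fbgemm.gather_along_first_dim(src, indices)

src = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)
indices = torch.randint(0, 1024, (1024,), device="cuda", dtype=torch.int32)
out = gather(src, indices)  # first call compiles the graph

With fullgraph=True, a missing meta/fake kernel surfaces as a hard tracing error rather than a silent graph break, which is why the tests enable it.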
