
Commit

Review fixes
jainapurva committed Sep 24, 2024
1 parent acb2afc commit ce8ad06
Showing 2 changed files with 2 additions and 2 deletions.
test/kernel/test_autotuner.py (2 changes: 1 addition & 1 deletion)
@@ -56,7 +56,7 @@ def test_int_mm(self, device, dtype):
             ("cuda", torch.float16),
         ]
     )
-    @unittest.skipIf(not is_H100, "Need H100")
+    @unittest.skipIf(not is_H100, "Needs H100")
     def test_int_mm_float8(self, device, dtype):
         from torchao.kernel import intmm

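For context, is_H100 is defined elsewhere in the test module and is not shown in this diff. A minimal sketch of such a gate, assuming it keys off the Hopper compute capability (9, 0), could look like the following; the exact torchao definition may differ.

import torch

# Hypothetical sketch of an H100 gate, not the actual torchao definition:
# skip float8 tests unless a Hopper-class GPU (compute capability >= 9.0) is available.
is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)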
torchao/kernel/intmm.py (2 changes: 1 addition & 1 deletion)
@@ -71,7 +71,7 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
     ) # (it seems the transpose makes cublas check the above j constraint on i)
     try:
         return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
-    except Exception as e:
+    except Exception:
         # fallback path, would run on H100 for float8 dtypes
         # Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn'
         return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
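The fallback path above recomputes the product in float32 and casts the result back to int32 when the fused integer mm raises (for example, for float8 inputs on H100). A minimal usage sketch, assuming torchao is installed and that safe_int_mm's internal device and shape checks are satisfied, might look like this; the shapes and call site are illustrative only.

import torch
from torchao.kernel.intmm import safe_int_mm

# Sketch only: multiply two int8 matrices; safe_int_mm returns an int32 result,
# either through the fused integer mm kernel or through the float32 fallback above.
a = torch.randint(-128, 127, (64, 32), dtype=torch.int8)
b = torch.randint(-128, 127, (32, 48), dtype=torch.int8)
out = safe_int_mm(a, b)
print(out.dtype, out.shape)  # torch.int32, torch.Size([64, 48])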
