Do not run AccelerateMatmul for pre-Volta GPUs
Himanshu Pathak authored and geekypathak21 committed Apr 21, 2023
1 parent 192f889 commit 31ee8fc
Showing 2 changed files with 12 additions and 0 deletions.
lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp (2 additions & 0 deletions)
@@ -87,6 +87,8 @@ class BlockedToMMA : public mlir::RewritePattern {
  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    if (computeCapability < 70)
      return failure();
    auto dotOp = cast<triton::DotOp>(op);
    // TODO: Check data-types and SM compatibility
    auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
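The guard above makes BlockedToMMA return failure() on devices below compute capability 70 (Volta), so the dot keeps its blocked layout and is lowered through the FMA path instead of a tensor-core MMA layout. As a rough illustration (not part of this commit), the sketch below mirrors the same capability encoding on the host side; the helper name mma_layout_enabled is made up here, and it assumes a CUDA build of PyTorch for querying the device.

import torch

def mma_layout_enabled() -> bool:
    # Mirrors the computeCapability < 70 guard added to BlockedToMMA:
    # (major, minor) is encoded as major * 10 + minor, e.g. (7, 0) -> 70 for Volta.
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor >= 70

print("Eligible for MMA (tensor core) layout:", mma_layout_enabled())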
python/triton/language/semantic.py (10 additions & 0 deletions)
@@ -5,6 +5,8 @@

from . import core as tl
from triton._C.libtriton.triton import ir
import torch
import triton

T = TypeVar('T')

@@ -1180,6 +1182,14 @@ def dot(lhs: tl.tensor,
        allow_tf32: bool,
        out_dtype: tl.dtype,
        builder: ir.builder) -> tl.tensor:
    if torch.version.hip is None:
        device = triton.runtime.jit.get_current_device()
        capability = triton.runtime.jit.get_device_capability(device)
        capability = capability[0] * 10 + capability[1]
        if capability < 70:
            assert (
                not rhs.dtype.is_fp16() and not rhs.dtype.is_fp8()
            ), "Float8 and Float16 types are not supported for compute capability < 70 (use Float32 or above)"
    assert lhs.type.is_block() and rhs.type.is_block()
    assert lhs.dtype == rhs.dtype, "lhs and rhs must have the same dtype!"
    assert len(lhs.shape) == 2 and len(rhs.shape) == 2
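The Python-side change mirrors the C++ guard at the API level: on the CUDA backend (torch.version.hip is None), tl.dot now asserts when it is given Float16 or Float8 operands on a device with compute capability below 70. A kernel can stay runnable on pre-Volta GPUs by widening its operands to Float32 before the dot, as in the hedged sketch below; this example is not from the commit, and the kernel name small_matmul and the 16x16 tile size are arbitrary illustrative choices that assume a CUDA device is available.

import torch
import triton
import triton.language as tl

@triton.jit
def small_matmul(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    # One program multiplies a single BLOCK x BLOCK row-major tile.
    rows = tl.arange(0, BLOCK)[:, None]
    cols = tl.arange(0, BLOCK)[None, :]
    a = tl.load(a_ptr + rows * BLOCK + cols)
    b = tl.load(b_ptr + rows * BLOCK + cols)
    # With this change, fp16/fp8 operands to tl.dot assert on capability < 70,
    # so the inputs are widened to fp32 to keep the kernel portable to pre-Volta GPUs.
    c = tl.dot(a.to(tl.float32), b.to(tl.float32))
    tl.store(c_ptr + rows * BLOCK + cols, c)

a = torch.randn((16, 16), device="cuda", dtype=torch.float16)
b = torch.randn((16, 16), device="cuda", dtype=torch.float16)
c = torch.empty((16, 16), device="cuda", dtype=torch.float32)
small_matmul[(1,)](a, b, c, BLOCK=16)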
