Skip to content

Commit e8edea9

Browse files
committed
Fix: triton pytest skip
1 parent c1fa677 commit e8edea9

File tree

4 files changed

+23
-23
lines changed

4 files changed

+23
-23
lines changed

benchmarks/benchmark_blockwise_scaled_linear_triton.py

+16-14
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
1-
import pandas as pd
21
import torch
3-
from tqdm import tqdm
42

53
if torch.cuda.is_available():
4+
import pandas as pd
5+
from tqdm import tqdm
66
from triton.testing import do_bench
77

8-
from torchao.float8.float8_utils import compute_error
9-
from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
10-
from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm
11-
from torchao.prototype.blockwise_fp8.blockwise_quantization import (
12-
fp8_blockwise_act_quant,
13-
fp8_blockwise_weight_quant,
14-
)
15-
from torchao.quantization.quant_api import (
16-
_int4_symm_per_token_quant_cutlass,
17-
_int8_symm_per_token_reduced_range_quant_cutlass,
18-
)
19-
from torchao.utils import is_sm_at_least_89
8+
from torchao.float8.float8_utils import compute_error
9+
from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
10+
from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import (
11+
blockwise_fp8_gemm,
12+
)
13+
from torchao.prototype.blockwise_fp8.blockwise_quantization import (
14+
fp8_blockwise_act_quant,
15+
fp8_blockwise_weight_quant,
16+
)
17+
from torchao.quantization.quant_api import (
18+
_int4_symm_per_token_quant_cutlass,
19+
_int8_symm_per_token_reduced_range_quant_cutlass,
20+
)
21+
from torchao.utils import is_sm_at_least_89
2022

2123

2224
def benchmark_microseconds(f, *args):

test/prototype/test_blockwise_triton.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import pytest
22
import torch
33

4+
triton = pytest.importorskip("triton", reason="Triton required to run this test")
5+
46
from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm
57
from torchao.prototype.blockwise_fp8.blockwise_quantization import (
68
fp8_blockwise_act_quant,

torchao/prototype/blockwise_fp8/blockwise_fp8_gemm_triton.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import torch
2-
3-
if torch.cuda.is_available():
4-
import triton
5-
import triton.language as tl
6-
from triton import Config
2+
import triton
3+
import triton.language as tl
4+
from triton import Config
75

86
# Original implementation at https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
97

torchao/prototype/blockwise_fp8/blockwise_quantization.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
from typing import Tuple
22

33
import torch
4-
5-
if torch.cuda.is_available():
6-
import triton
7-
import triton.language as tl
4+
import triton
5+
import triton.language as tl
86

97

108
@triton.jit

0 commit comments

Comments (0)