Merge pull request triton-lang#1 from dfukalov/dfukalov/work-5
[Triton-MLIR][ROCM] Updated tests for Ld/St support.
B1tway authored Nov 30, 2022
2 parents b2c1111 + f7bea48 commit c964485
Showing 1 changed file with 115 additions and 46 deletions.
161 changes: 115 additions & 46 deletions test/Conversion/tritongpu_to_llvm.mlir
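
A note on the testing mechanism, since the whole diff builds on it: the first hunk below replaces the single-prefix FileCheck invocation with a multi-prefix one, so one test file can carry assertions for both the AMD (GCN) and NVIDIA (PTX) lowerings. A minimal sketch of how the prefixes interact (the second RUN line and the kernel are illustrative assumptions, not content of this commit):

// RUN: triton-opt %s --convert-triton-gpu-to-llvm | FileCheck --check-prefixes=CHECK,GCN %s
// RUN: triton-opt %s --convert-triton-gpu-to-llvm | FileCheck --check-prefixes=CHECK,PTX %s
// A CHECK line must match in every run; a GCN line only in the first run,
// a PTX line only in the second. A prefix FileCheck was not asked to check,
// such as the deactivated XHECK lines below, is ignored entirely.
// CHECK-LABEL: toy_kernel
func @toy_kernel() {
  // CHECK: llvm.return
  return
}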
@@ -1,9 +1,9 @@
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck %s
// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm | FileCheck --check-prefixes=CHECK,GCN %s

module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
// Here the 128 comes from the 4 in the module attribute multiplied by 32
// XHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : i32} {{.*}}
// PTX: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : i32} {{.*}}
func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
// CHECK: llvm.return
return
@@ -17,7 +17,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_load
func @basic_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.inline_asm
// GCN-SAME: global_load_dword $0, $1, off offset:0
// CHECK: llvm.inline_asm
// GCN-SAME: global_load_dword $0, $1, off offset:0
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
return
}
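
For readers skimming the new checks: the -SAME suffix (GCN-SAME, PTX-SAME) is FileCheck's way of anchoring a pattern to the same output line as the previous match. So a pair like the one above,

// CHECK: llvm.inline_asm
// GCN-SAME: global_load_dword $0, $1, off offset:0

only passes if the global_load_dword text occurs later on the very line where llvm.inline_asm matched; this is how the tests tie the asm string to the op that carries it.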
@@ -30,9 +32,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: vectorized_load
func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr<f32>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) {
// CHECK: llvm.inline_asm
// CHECK-SAME: ld.global.b32
// GCN-SAME: global_load_dword $0, $1, off offset:0
// PTX-SAME: ld.global.b32
// CHECK: llvm.inline_asm
// CHECK-SAME: ld.global.b32
// GCN-SAME: global_load_dword $0, $1, off offset:0
// PTX-SAME: ld.global.b32
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
return
}
@@ -45,9 +49,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: vectorized_load_f16
func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr<f16>, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) {
// CHECK: llvm.inline_asm
// CHECK-SAME: ld.global.b16
// GCN-SAME: global_load_ushort $0, $1, off offset:0
// PTX-SAME: ld.global.b16
// CHECK: llvm.inline_asm
// CHECK-SAME: ld.global.b16
// GCN-SAME: global_load_ushort $0, $1, off offset:0
// PTX-SAME: ld.global.b16
%1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf16, #blocked0>
return
}
@@ -97,27 +103,39 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>

// Load 4 elements from vector0
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// GCN: llvm.inline_asm {{.*}}global_load_dword $0, $1, off offset:0
// GCN: llvm.inline_asm {{.*}}global_load_dword $0, $1, off offset:0
// GCN: llvm.inline_asm {{.*}}global_load_dword $0, $1, off offset:0
// GCN: llvm.inline_asm {{.*}}global_load_dword $0, $1, off offset:0
// PTX: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// PTX: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// PTX: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// PTX: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];

// Load 4 elements from vector1
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// GCN: llvm.inline_asm {{.*}}global_load_dword $0, $1, off offset:0
// GCN: llvm.inline_asm {{.*}}global_load_dword $0, $1, off offset:0
// GCN: llvm.inline_asm {{.*}}global_load_dword $0, $1, off offset:0
// GCN: llvm.inline_asm {{.*}}global_load_dword $0, $1, off offset:0
// PTX: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// PTX: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// PTX: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
// PTX: "@${{.*}} ld.global.b32 { ${{.*}} }, [ ${{.*}} + 0 ];
%9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
%12 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<256x!tt.ptr<f32>, #blocked0>
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>

// Store 4 elements to global
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// CHECK: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// GCN: llvm.inline_asm {{.*}}global_store_dword $1, $0, off offset:0
// GCN: llvm.inline_asm {{.*}}global_store_dword $1, $0, off offset:0
// GCN: llvm.inline_asm {{.*}}global_store_dword $1, $0, off offset:0
// GCN: llvm.inline_asm {{.*}}global_store_dword $1, $0, off offset:0
// PTX: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// PTX: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// PTX: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// PTX: @${{.*}} st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
tt.store %13, %11 : tensor<256xf32, #blocked0>
return
}
@@ -141,10 +159,20 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>

// Load 4 elements from A with a single vectorized load instruction
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
// GCN: llvm.inline_asm
// GCN-SAME: global_load_dword $0, $4, off offset:0
// GCN-SAME: global_load_dword $1, $4, off offset:4
// GCN-SAME: global_load_dword $2, $4, off offset:8
// GCN-SAME: global_load_dword $3, $4, off offset:12
// PTX: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];

// Load 4 elements from B with a single vectorized load instruction
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
// GCN: llvm.inline_asm
// GCN-SAME: global_load_dword $0, $4, off offset:0
// GCN-SAME: global_load_dword $1, $4, off offset:4
// GCN-SAME: global_load_dword $2, $4, off offset:8
// GCN-SAME: global_load_dword $3, $4, off offset:12
// PTX: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];

%9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
@@ -153,7 +181,12 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>

// Store 4 elements to global with a single vectorized store instruction
// CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
// GCN: llvm.inline_asm
// GCN-SAME: global_store_dword $4, $0, off offset:0
// GCN-SAME: global_store_dword $4, $1, off offset:4
// GCN-SAME: global_store_dword $4, $2, off offset:8
// GCN-SAME: global_store_dword $4, $3, off offset:12
// PTX: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
tt.store %13, %11 : tensor<256xf32, #blocked0>
return
}
@@ -165,6 +198,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
// Note that %n_elements doesn't have a "tt.divisibility" hint, so Triton assumes its divisibility is 1; this affects the mask's alignment and further restricts the load/store ops' vector width to 1.
module attributes {"triton_gpu.num-warps" = 2 : i32} {
// CHECK-LABEL: vecadd_masked_vec1
func @vecadd_masked_vec1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_elements: i32) {
%c64_i32 = arith.constant 64 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
@@ -179,7 +213,8 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
%9 = tt.splat %n_elements : (i32) -> tensor<64xi32, #blocked>
%10 = "triton_gpu.cmpi"(%4, %9) {predicate = 2 : i64} : (tensor<64xi32, #blocked>, tensor<64xi32, #blocked>) -> tensor<64xi1, #blocked>
// load op has a vector width = 1 due to the %mask's alignment
// CHECK: ld.global.b32
// GCN: llvm.inline_asm {{.*}}global_load_dword $0, $1, off offset:0
// PTX: ld.global.b32
%11 = tt.load %6, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32, #blocked>
%12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32, #blocked>
%13 = arith.addf %11, %12 : tensor<64xf32, #blocked>
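
To make the divisibility note above concrete: the vector width is derived from what alignment the compiler can prove, so hinting %n_elements would re-enable wider accesses. A hypothetical variant, not part of this commit or of the test suite (body trivialized):

// With the hint the mask boundary is provably 16-aligned, so the lowering
// could legally emit vectorized accesses (e.g. ld.global.v4.b32 on PTX)
// instead of the width-1 loads checked in this hunk.
func @vecadd_masked_vec4_hinted(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
                                %n_elements: i32 {tt.divisibility = 16 : i32}) {
  return
}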
@@ -195,7 +230,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: global_load_store_vec8
func @global_load_store_vec8(%arg0: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 4 : i32}, %arg3: i32) {
%c256_i32 = arith.constant 256 : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
%1 = arith.muli %0, %c256_i32 : i32
@@ -208,12 +243,32 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>

// Load 8 elements from A with two vectorized load instructions
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
// GCN: llvm.inline_asm
// GCN-SAME: global_load_dword $0, $4, off offset:0
// GCN-SAME: global_load_dword $1, $4, off offset:4
// GCN-SAME: global_load_dword $2, $4, off offset:8
// GCN-SAME: global_load_dword $3, $4, off offset:12
// GCN: llvm.inline_asm
// GCN-SAME: global_load_dword $0, $4, off offset:0
// GCN-SAME: global_load_dword $1, $4, off offset:4
// GCN-SAME: global_load_dword $2, $4, off offset:8
// GCN-SAME: global_load_dword $3, $4, off offset:12
// PTX: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
// PTX: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];

// Load 8 elements from B with two vectorized load instructions
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
// CHECK: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
// GCN: llvm.inline_asm
// GCN-SAME: global_load_dword $0, $4, off offset:0
// GCN-SAME: global_load_dword $1, $4, off offset:4
// GCN-SAME: global_load_dword $2, $4, off offset:8
// GCN-SAME: global_load_dword $3, $4, off offset:12
// GCN: llvm.inline_asm
// GCN-SAME: global_load_dword $0, $4, off offset:0
// GCN-SAME: global_load_dword $1, $4, off offset:4
// GCN-SAME: global_load_dword $2, $4, off offset:8
// GCN-SAME: global_load_dword $3, $4, off offset:12
// PTX: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];
// PTX: @${{.*}} ld.global.v4.b32 { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} }, [ ${{.*}} + 0 ];

%9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0>
@@ -222,8 +277,18 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
%13 = tt.addptr %12, %4 : tensor<256x!tt.ptr<f32>, #blocked0>

// Store 8 elements to global with two vectorized store instructions
// CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
// CHECK: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
// GCN: llvm.inline_asm
// GCN-SAME: global_store_dword $4, $0, off offset:0
// GCN-SAME: global_store_dword $4, $1, off offset:4
// GCN-SAME: global_store_dword $4, $2, off offset:8
// GCN-SAME: global_store_dword $4, $3, off offset:12
// GCN: llvm.inline_asm
// GCN-SAME: global_store_dword $4, $0, off offset:0
// GCN-SAME: global_store_dword $4, $1, off offset:4
// GCN-SAME: global_store_dword $4, $2, off offset:8
// GCN-SAME: global_store_dword $4, $3, off offset:12
// PTX: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
// PTX: @$5 st.global.v4.b32 [ ${{.*}} + 0 ], { ${{.*}}, ${{.*}}, ${{.*}}, ${{.*}} };
tt.store %13, %11 : tensor<256xf32, #blocked0>
return
}
@@ -263,7 +328,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_make_range
func @basic_make_range() {
// XHECK: nvvm.read.ptx.sreg.tid.x
// PTX: nvvm.read.ptx.sreg.tid.x
// CHECK: llvm.mlir.undef
// CHECK: llvm.insertvalue
// CHECK: llvm.insertvalue
@@ -303,7 +368,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_program_id
func @basic_program_id() {
// XHECK: nvvm.read.ptx.sreg.ctaid.x : i32
// PTX: nvvm.read.ptx.sreg.ctaid.x : i32
%0 = tt.get_program_id {axis = 0 : i32} : i32
return
}
@@ -554,9 +619,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: basic_store
func @basic_store(%ptrs: tensor<256x!tt.ptr<f32>, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) {
// CHECK: llvm.inline_asm
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// GCN-SAME: global_store_dword $1, $0, off offset:0
// PTX-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// CHECK: llvm.inline_asm
// CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
// GCN-SAME: global_store_dword $1, $0, off offset:0
// PTX-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} };
tt.store %ptrs, %vals, %mask : tensor<256xf32, #blocked0>
return
}
@@ -571,7 +638,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_layout_blocked_blocked
func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// XHECK: nvvm.barrier0
// PTX: nvvm.barrier0
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.store
@@ -588,7 +655,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// XHECK: nvvm.barrier0
// PTX: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<1xf32>, 3>
// CHECK: llvm.load
@@ -619,12 +686,12 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_layout_blocked_blocked_vec
func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// XHECK: nvvm.barrier0
// PTX: nvvm.barrier0
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// XHECK: nvvm.barrier0
// PTX: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: llvm.load
Expand All @@ -643,18 +710,18 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_layout_blocked_blocked_multi_rep
func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
// CHECK: llvm.mlir.addressof @global_smem
// XHECK: nvvm.barrier0
// PTX: nvvm.barrier0
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// XHECK: nvvm.barrier0
// PTX: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// XHECK: nvvm.barrier0
// PTX: nvvm.barrier0
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// XHECK: nvvm.barrier0
// PTX: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
// CHECK: llvm.load
@@ -707,12 +774,12 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8>
// CHECK-LABEL: convert_layout_mma_blocked
func @convert_layout_mma_blocked(%arg0: tensor<32x16xf32, #mma>) {
// XHECK: nvvm.barrier0
// PTX: nvvm.barrier0
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// CHECK: llvm.store
// CHECK-SAME: !llvm.ptr<vector<2xf32>, 3>
// XHECK: nvvm.barrier0
// PTX: nvvm.barrier0
// CHECK: llvm.load
// CHECK-SAME: !llvm.ptr<vector<4xf32>, 3>
%0 = triton_gpu.convert_layout %arg0 : (tensor<32x16xf32, #mma>) -> tensor<32x16xf32, #blocked0>
@@ -770,10 +837,11 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: matmul_kernel_dot_operand_layout
func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
// PTX: ldmatrix.sync.aligned.m8n8.x4.shared.b16
%a_mat = triton_gpu.convert_layout %a : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #dot_operand_a>
%b_mat = triton_gpu.convert_layout %b : (tensor<32x256xf16, #shared>) -> tensor<32x256xf16, #dot_operand_b>

@@ -795,10 +863,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: matmul884_kernel_dot_operand_layout
func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma>
// CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16
// PTX: ldmatrix.sync.aligned.m8n8.x4.shared.b16
%a_mat = triton_gpu.convert_layout %a : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #dot_operand_a>
%b_mat = triton_gpu.convert_layout %b : (tensor<32x256xf16, #shared>) -> tensor<32x256xf16, #dot_operand_b>

