diff --git a/src/cute-gemm-tma-gma/gemm.cu b/src/cute-gemm-tma-gma/gemm.cu index cf53316..18dedbe 100644 --- a/src/cute-gemm-tma-gma/gemm.cu +++ b/src/cute-gemm-tma-gma/gemm.cu @@ -308,8 +308,10 @@ __global__ static void gemm_device( constexpr int kTmaTransactionBytes = size(sA) * sizeof_bits_v / 8 + size(sB) * sizeof_bits_v / 8; + cfk::barrierInit(tma_load_mbar[0], 1); + cfk::copy(tAgA(_, stage), tBgB(_, stage), tAsA(_, 0), tBsB(_, 0), - tma_load_a, tma_load_b, tma_load_mbar, mcast_mask_a, + tma_load_a, tma_load_b, tma_load_mbar[0], mcast_mask_a, mcast_mask_b); cfk::gemm(tiled_mma, tCrA, tCrB, tCrC);