diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
index bef34881ca41bc..90167ac86a8e1a 100644
--- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
+++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
@@ -939,14 +939,14 @@ struct MHAHelper {
     // wv_scratch_b: [rnd_up(kv_len, block_size), Hk, scratch_b_size]
     void exec_kernel_multiple(const PlainTensor& query, const PlainTensor& present_value, const PlainTensor& output_emb,
                               const PlainTensor& qk_scratch_b, const PlainTensor& wv_scratch_b, const int32_t* block_table, size_t ithr, size_t q_blk,
-                              size_t hk, size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes, float* score_output) {
+                              size_t hq_beg, size_t hq_end, size_t hk, size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes, float* score_output) {
         auto q_start = q_blk * _block_size;
         auto q_end = std::min(q_start + _block_size, q_len);
         auto q_cnt = q_end - q_start;
         constexpr bool q_is_xf16 = one_of(precision_of<DATA_TYPE>::value, ov::element::bf16, ov::element::f16);
         constexpr bool q_cache_is_same = precision_of<DATA_TYPE>::value == precision_of<KVCACHE_TYPE>::value;
         auto cur_kv_len_blocks = div_up(cur_kv_len, _block_size);
-        for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
+        for (size_t h = hq_beg; h < hq_end; h++) {
             auto* q_ptr = query.ptr<DATA_TYPE>(h, q_start, 0);
             float* c_ptr = _weight.ptr<float>(ithr, h, 0, 0);
             // for each query block, loop through all key blocks
@@ -1065,13 +1065,14 @@ struct MHAHelper {
     // weight: [nthr, H, 32, rnd_up(kv_len, block_size)]
     // output: [nthr, 32, H, S]
     void exec_kernel_one_bh(const PlainTensor& query, const PlainTensor& present_key, const PlainTensor& present_value, const PlainTensor& output_emb,
-                            const int32_t* block_table, size_t ithr, size_t hk, size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes, float* score_output) {
+                            const int32_t* block_table, size_t ithr, size_t hq_beg, size_t hq_end, size_t hk,
+                            size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes, float* score_output) {
         if (one_of(_fastpath_valid_prec, ov::element::bf16, ov::element::f16)) {
             _gemv->tile_config();
             for (size_t pk = 0, i = 0; pk < cur_kv_len; pk += _block_size, i++) {
                 auto block_number = block_table[i];
                 for (size_t pq = 0; pq < q_len; pq++) {
-                    for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
+                    for (size_t h = hq_beg; h < hq_end; h++) {
                         (*_gemv)(query.ptr<DATA_TYPE>(h, pq), present_key.ptr<KVCACHE_TYPE>(block_number, hk),
                                  _weight.ptr<float>(ithr, h, pq) + pk);
                     }
@@ -1082,7 +1083,7 @@ struct MHAHelper {
             for (size_t pk = 0, i = 0; pk < cur_kv_len; pk += _block_size, i++) {
                 auto block_number = block_table[i];
                 for (size_t pq = 0; pq < q_len; pq++) {
-                    for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
+                    for (size_t h = hq_beg; h < hq_end; h++) {
                         dot_product_block(query.ptr<DATA_TYPE>(h, pq), present_key.ptr<KVCACHE_TYPE>(block_number, hk),
                                           _weight.ptr<float>(ithr, h, pq) + pk, _S, std::min(_block_size, cur_kv_len - pk));
                     }
@@ -1091,7 +1092,7 @@ struct MHAHelper {
         }

         for (size_t pq = 0; pq < q_len; pq++) {
-            for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
+            for (size_t h = hq_beg; h < hq_end; h++) {
                 // apply attention mask & softmax
                 float* alibi_lookup = nullptr;
                 float alibi_slope = 0.f;
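Reviewer note on the hunks above: every kernel walks the KV sequence in `_block_size` chunks, and the final chunk may be partial, which is what the `std::min(_block_size, cur_kv_len - pk)` argument to `dot_product_block` accounts for. A minimal standalone sketch of that block iteration (the sizes below are made up for illustration; the real values come from the PagedAttention block layout):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    // Assumed sizes for illustration only.
    const size_t block_size = 32;
    const size_t cur_kv_len = 70;  // deliberately not a multiple of block_size
    for (size_t pk = 0, i = 0; pk < cur_kv_len; pk += block_size, i++) {
        // Same tail handling as the dot_product_block calls in the diff.
        const size_t valid = std::min(block_size, cur_kv_len - pk);
        std::printf("block %zu: tokens [%zu, %zu), %zu valid\n", i, pk, pk + valid, valid);
    }
    return 0;  // prints blocks of 32, 32 and 6 valid tokens
}
```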
@@ -1122,7 +1123,7 @@ struct MHAHelper {
             auto block_number = block_table[i];
             auto* v = present_value.ptr<KVCACHE_TYPE>(block_number, hk);
             for (size_t pq = 0; pq < q_len; pq++) {
-                for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
+                for (size_t h = hq_beg; h < hq_end; h++) {
                     attn_acc_value_block(_output.ptr<float>(ithr, pq, h),
                                          _weight.ptr<float>(ithr, h, pq) + pv,
                                          v,
@@ -1133,7 +1134,7 @@ struct MHAHelper {
         }
         // convert to dst
         for (size_t pq = 0; pq < q_len; pq++)
-            for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++)
+            for (size_t h = hq_beg; h < hq_end; h++)
                 cvt_copy(output_emb.ptr<DATA_TYPE>(pq, h * _SV), _output.ptr<float>(ithr, pq, h), _SV);
     }

@@ -1162,8 +1163,38 @@ struct MHAHelper {
         // aligned to cache line (64 bytes = 16*sizeof(float)) to avoid false sharing
         _weight_bhl.resize({B, _H, q_len, rnd_up(max_context_len, std::max(_block_size, size_t{16}))});

-        parallel_for3d_dynamic(B, kv_len_in_blocks, _Hk, [&](size_t b, size_t pk_in_blocks, size_t hk) {
+        // for small batches the dynamic scheduler has notable overhead
+        bool prefer_static_loop;
+        // with at most 2 work items per thread, loop along H instead of HK
+        bool loop_hk = B * kv_len_in_blocks * _Hk > 2 * _nthr;
+        if (B <= 32) {
+            prefer_static_loop = true;
+            // small batch where every sequence has the same kv_len (like the SDPA case)
+            auto kv_len = past_lens.ptr<int32_t>()[0];
+            for (size_t b = 1; b < B; b++) {
+                if (past_lens.ptr<int32_t>()[b] != kv_len)
+                    prefer_static_loop = false;
+            }
+        } else {
+            // for bigger batches, skip the uniformity check to save its cost
+            prefer_static_loop = false;
+        }
+        auto get_h_params = [](bool loop_hk, size_t hx, size_t h_each_group_len, size_t& hq_beg, size_t& hq_end, size_t& hk) {
+            if (loop_hk) {
+                hk = hx;
+                hq_beg = hk * h_each_group_len;
+                hq_end = (hk + 1) * h_each_group_len;
+            } else {
+                hq_beg = hx;
+                hq_end = hx + 1;
+                hk = hx / h_each_group_len;
+            }
+        };
+        auto loop_qk = [&](size_t b, size_t pk_in_blocks, size_t hx) {
             auto context_len = static_cast<size_t>(past_lens.ptr<int32_t>()[b]) + 1;
+            size_t hk, hq_beg, hq_end;
+            get_h_params(loop_hk, hx, _h_each_group_len, hq_beg, hq_end, hk);
+
             // kv_len must be valid
             auto pk = pk_in_blocks * _block_size;
             if (pk < context_len) {
@@ -1171,7 +1202,7 @@ struct MHAHelper {
                 if (one_of(_fastpath_valid_prec, ov::element::bf16, ov::element::f16)) {
                     _gemv->tile_config();
                     for (size_t pq = 0; pq < q_len; pq++) {
-                        for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
+                        for (size_t h = hq_beg; h < hq_end; h++) {
                             (*_gemv)(query.ptr<DATA_TYPE>(b, h, pq), present_key.ptr<KVCACHE_TYPE>(block_number, hk),
                                      _weight_bhl.ptr<float>(b, h, pq) + pk);
                         }
@@ -1179,16 +1210,16 @@ struct MHAHelper {
                     _gemv->tile_release();
                 } else {
                     for (size_t pq = 0; pq < q_len; pq++) {
-                        for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
+                        for (size_t h = hq_beg; h < hq_end; h++) {
                             dot_product_block(query.ptr<DATA_TYPE>(b, h, pq), present_key.ptr<KVCACHE_TYPE>(block_number, hk),
                                               _weight_bhl.ptr<float>(b, h, pq) + pk, _S, std::min(_block_size, context_len - pk));
                         }
                    }
                }
            }
-        });
+        };

-        parallel_for3d_dynamic(B, _H, q_len, [&](size_t b, size_t h, size_t pq) {
+        auto loop_softmax = [&](size_t b, size_t h, size_t pq) {
            auto cur_kv_len = static_cast<size_t>(past_lens.ptr<int32_t>()[b]) + 1;
            auto ncausal = cur_kv_len;
            // apply attention mask & softmax
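Reviewer note: the `get_h_params` lambda introduced above is the heart of this change. Each parallel work item now covers either one whole KV-head group (`loop_hk == true`) or one single query head (`loop_hk == false`). A self-contained sketch of the mapping (the GQA shape is an assumed example, not taken from the kernel):

```cpp
#include <cstddef>
#include <cstdio>

// Same mapping as the get_h_params lambda in the diff above.
void get_h_params(bool loop_hk, size_t hx, size_t h_each_group_len,
                  size_t& hq_beg, size_t& hq_end, size_t& hk) {
    if (loop_hk) {
        hk = hx;                               // hx enumerates KV heads
        hq_beg = hk * h_each_group_len;        // first query head of the group
        hq_end = (hk + 1) * h_each_group_len;  // one past the last query head
    } else {
        hq_beg = hx;                           // hx enumerates query heads
        hq_end = hx + 1;                       // exactly one query head per item
        hk = hx / h_each_group_len;            // KV head shared by this query head
    }
}

int main() {
    // Assumed GQA shape: H = 8 query heads, Hk = 2 KV heads,
    // so h_each_group_len = H / Hk = 4.
    size_t hq_beg = 0, hq_end = 0, hk = 0;
    get_h_params(true, 1, 4, hq_beg, hq_end, hk);
    std::printf("loop_hk: hx=1 -> hk=%zu, hq=[%zu, %zu)\n", hk, hq_beg, hq_end);  // hk=1, hq=[4, 8)
    get_h_params(false, 6, 4, hq_beg, hq_end, hk);
    std::printf("loop_H : hx=6 -> hk=%zu, hq=[%zu, %zu)\n", hk, hq_beg, hq_end);  // hk=1, hq=[6, 7)
    return 0;
}
```

In either mode the loop body `for (size_t h = hq_beg; h < hq_end; h++)` visits every query head exactly once across all work items, which is why the kernels themselves only needed their loop bounds changed.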
@@ -1210,7 +1241,16 @@ struct MHAHelper {
                                ov::element::f32,
                                ov::element::f32,
                                alibi_slope);
-        });
+        };
+
+        size_t h_dims = loop_hk ? _Hk : _H;
+        if (prefer_static_loop) {
+            parallel_for3d(B, kv_len_in_blocks, h_dims, loop_qk);
+            parallel_for3d(B, _H, q_len, loop_softmax);
+        } else {
+            parallel_for3d_dynamic(B, kv_len_in_blocks, h_dims, loop_qk);
+            parallel_for3d_dynamic(B, _H, q_len, loop_softmax);
+        }

         if (output_score) {
             parallel_for2d_dynamic(B, q_len, [&](size_t b, size_t pq) {
@@ -1229,16 +1269,19 @@ struct MHAHelper {
             memset(_output_bhl.ptr<float>(ithr, 0, 0, 0, 0), 0, _output_bhl.stride(0) * sizeof(float));
         });

-        parallel_for3d_dynamic(B, kv_len_in_blocks, _Hk, [&](size_t b, size_t pv_in_blocks, size_t hk) {
+        auto loop_wk = [&](size_t b, size_t pv_in_blocks, size_t hx) {
             auto ithr = parallel_get_thread_num();
             auto context_len = static_cast<size_t>(past_lens.ptr<int32_t>()[b]) + 1;
             auto pv = pv_in_blocks * _block_size;
+            size_t hk, hq_beg, hq_end;
+            get_h_params(loop_hk, hx, _h_each_group_len, hq_beg, hq_end, hk);
+
             // kv_len must be valid
             if (pv < context_len) {
                 auto block_number = block_indices.ptr<int32_t>()[block_indices_begins.ptr<int32_t>()[b] + pv_in_blocks];
                 auto* v = present_value.ptr<KVCACHE_TYPE>(block_number, hk);
                 for (size_t pq = 0; pq < q_len; pq++) {
-                    for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
+                    for (size_t h = hq_beg; h < hq_end; h++) {
                         attn_acc_value_block(_output_bhl.ptr<float>(ithr, b, pq, h),
                                              _weight_bhl.ptr<float>(b, h, pq) + pv,
                                              v,
@@ -1247,7 +1290,13 @@ struct MHAHelper {
                     }
                 }
             }
-        });
+        };
+
+        if (prefer_static_loop) {
+            parallel_for3d(B, kv_len_in_blocks, loop_hk ? _Hk : _H, loop_wk);
+        } else {
+            parallel_for3d_dynamic(B, kv_len_in_blocks, loop_hk ? _Hk : _H, loop_wk);
+        }

         parallel_for3d(B, _H, q_len, [&](size_t b, size_t h, size_t pq) {
             auto* temp = _output_bhl.ptr<float>(0, b, pq, h);
@@ -1416,7 +1465,23 @@ struct MHA {
             }
         });

-        parallel_for2d_dynamic(attn_work_count, Hk, [&](size_t w, size_t hk) {
+        // loop along the HK dimension: with mixed first/second tokens and enough work items, loop HK to reuse KV in the CPU cache;
+        // with a small work item count, prefer to loop H to create more work items and avoid thread imbalance
+        bool loop_hk = !(_workitems.get_reorder_max_batch_size() == past_lens.m_dims[0] ||  // if only first token, loop H
+                         attn_work_count * Hk <= 2 * _helper._nthr);                        // or at most 2 work items per thread, loop H
+
+        parallel_for2d_dynamic(attn_work_count, loop_hk ? Hk : _helper._H, [&](size_t w, size_t hx) {
+            size_t hk, hq_beg, hq_end;
+            if (loop_hk) {
+                hk = hx;
+                hq_beg = hk * _helper._h_each_group_len;
+                hq_end = (hk + 1) * _helper._h_each_group_len;
+            } else {
+                hq_beg = hx;
+                hq_end = hx + 1;
+                hk = hx / _helper._h_each_group_len;
+            }
+
             const auto& item = _workitems.get_attn_work_item(w);
             const auto batch_in_seq = item.batch_in_seq;
             const auto batch_in_token = subsequence_begins.ptr<int32_t>()[batch_in_seq];
@@ -1434,7 +1499,7 @@ struct MHA {
                 _helper.exec_kernel_one_bh(q.slice(0, batch_in_token, batch_in_token), k_cache, v_cache,
                                            output_emb.slice(0, batch_in_token, batch_in_token),
                                            block_indices.ptr<int32_t>() + block_indices_begins.ptr<int32_t>()[batch_in_seq],
-                                           ithr, hk, 1ul, cur_kv_len, alibi_slopes,
+                                           ithr, hq_beg, hq_end, hk, 1ul, cur_kv_len, alibi_slopes,
                                            score_output);
             } else {
                 const auto batch_in_reorder = item.batch_in_reorder;
@@ -1461,6 +1526,8 @@ struct MHA {
                     block_indices.ptr<int32_t>() + block_indices_begins.ptr<int32_t>()[batch_in_seq],
                     ithr,
                     q_blk,
+                    hq_beg,
+                    hq_end,
                     hk,
                     q_len,
                     cur_kv_len,
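Reviewer note: taken together, the diff introduces two independent scheduling decisions, one for work granularity (`loop_hk`) and one for the scheduler itself (`prefer_static_loop`). The sketch below restates them as a free function; the thresholds (`2 * nthr`, `B <= 32`) come from the diff, while the function and struct names are illustrative:

```cpp
#include <cstddef>
#include <cstdint>

struct SchedulePolicy {
    bool loop_hk;             // true: one work item per KV-head group; false: one per query head
    bool prefer_static_loop;  // true: static partitioning; false: dynamic scheduling
};

// Mirrors the decision logic added around line 1163 of the diff. past_lens
// holds the per-sequence past KV lengths, as in the kernel.
SchedulePolicy choose_schedule(size_t B, size_t kv_len_in_blocks, size_t Hk,
                               size_t nthr, const int32_t* past_lens) {
    SchedulePolicy p{};
    // With at most ~2 work items per thread, split head groups into single
    // query heads to expose more parallelism.
    p.loop_hk = B * kv_len_in_blocks * Hk > 2 * nthr;
    // The dynamic scheduler only pays off when per-item work is uneven; for a
    // small batch of equal-length sequences (the SDPA-like case) a static
    // loop avoids its overhead.
    p.prefer_static_loop = false;
    if (B <= 32) {
        p.prefer_static_loop = true;
        for (size_t b = 1; b < B; b++) {
            if (past_lens[b] != past_lens[0]) {
                p.prefer_static_loop = false;
                break;  // the diff scans all entries; breaking early is equivalent
            }
        }
    }
    return p;
}
```

The `MHA` hunks at the bottom of the diff apply the same `loop_hk` idea with one extra condition: when every work item is a first-token item (`get_reorder_max_batch_size() == past_lens.m_dims[0]`), they always loop over query heads, since there is no decode-phase KV reuse to exploit.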