[CPU] Optimize small batch case for PagedAttention #27847

Merged
Changes from 1 commit
105 changes: 86 additions & 19 deletions src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
@@ -939,14 +939,14 @@ struct MHAHelper {
// wv_scratch_b: [rnd_up(kv_len, block_size), Hk, scratch_b_size]
void exec_kernel_multiple(const PlainTensor& query, const PlainTensor& present_value, const PlainTensor& output_emb,
const PlainTensor& qk_scratch_b, const PlainTensor& wv_scratch_b, const int32_t* block_table, size_t ithr, size_t q_blk,
-size_t hk, size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes, float* score_output) {
+size_t hq_beg, size_t hq_end, size_t hk, size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes, float* score_output) {
auto q_start = q_blk * _block_size;
auto q_end = std::min(q_start + _block_size, q_len);
auto q_cnt = q_end - q_start;
constexpr bool q_is_xf16 = one_of(precision_of<DATA_TYPE>::value, ov::element::bf16, ov::element::f16);
constexpr bool q_cache_is_same = precision_of<DATA_TYPE>::value == precision_of<KVCACHE_TYPE>::value;
auto cur_kv_len_blocks = div_up(cur_kv_len, _block_size);
-for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
+for (size_t h = hq_beg; h < hq_end; h++) {
auto* q_ptr = query.ptr<DATA_TYPE>(h, q_start, 0);
float* c_ptr = _weight.ptr<float>(ithr, h, 0, 0);
// for each query block, loop through all key blocks
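The signature change above is the core refactor: instead of always deriving the query-head range from the kv head hk, the kernel now receives an explicit [hq_beg, hq_end) range, so a caller can hand it either a whole query-head group or a single query head. A minimal sketch of the mapping the new parameters encode (standalone helper, name hypothetical), assuming _h_each_group_len = H / Hk for grouped-query attention:

    // Sketch only, not part of the patch: translate a loop index hx into
    // (hk, hq_beg, hq_end) depending on which dimension is parallelized.
    void map_head_range(size_t hx, bool loop_hk, size_t h_each_group_len,
                        size_t& hk, size_t& hq_beg, size_t& hq_end) {
        if (loop_hk) {
            // hx indexes a kv head: cover its whole query-head group
            hk = hx;
            hq_beg = hk * h_each_group_len;
            hq_end = (hk + 1) * h_each_group_len;
        } else {
            // hx indexes a single query head: derive the kv head it reads from
            hq_beg = hx;
            hq_end = hx + 1;
            hk = hx / h_each_group_len;
        }
    }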
@@ -1065,13 +1065,14 @@ struct MHAHelper {
// weight: [nthr, H, 32, rnd_up(kv_len, block_size)]
// output: [nthr, 32, H, S]
void exec_kernel_one_bh(const PlainTensor& query, const PlainTensor& present_key, const PlainTensor& present_value, const PlainTensor& output_emb,
-const int32_t* block_table, size_t ithr, size_t hk, size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes, float* score_output) {
+const int32_t* block_table, size_t ithr, size_t hq_beg, size_t hq_end, size_t hk,
+size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes, float* score_output) {
if (one_of(_fastpath_valid_prec, ov::element::bf16, ov::element::f16)) {
_gemv->tile_config();
for (size_t pk = 0, i = 0; pk < cur_kv_len; pk += _block_size, i++) {
auto block_number = block_table[i];
for (size_t pq = 0; pq < q_len; pq++) {
-for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
+for (size_t h = hq_beg; h < hq_end; h++) {
(*_gemv)(query.ptr<DATA_TYPE>(h, pq), present_key.ptr<KVCACHE_TYPE>(block_number, hk),
_weight.ptr<float>(ithr, h, pq) + pk);
}
@@ -1082,7 +1083,7 @@
for (size_t pk = 0, i = 0; pk < cur_kv_len; pk += _block_size, i++) {
auto block_number = block_table[i];
for (size_t pq = 0; pq < q_len; pq++) {
-for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
+for (size_t h = hq_beg; h < hq_end; h++) {
dot_product_block(query.ptr<DATA_TYPE>(h, pq), present_key.ptr<KVCACHE_TYPE>(block_number, hk),
_weight.ptr<float>(ithr, h, pq) + pk, _S, std::min(_block_size, cur_kv_len - pk));
}
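Both branches of this kernel compute the same scores q·k for every key position in a block; only the instruction path differs (AMX-tiled gemv for bf16/f16 vs. the generic dot_product_block). A scalar reference for the generic path, as a sketch (assumes a row-major [block_size, S] key block; the real cache layout may differ):

    // Sketch only: scalar equivalent of one dot_product_block call.
    void dot_product_block_ref(const float* q, const float* k_block,
                               float* out, size_t S, size_t block_len) {
        for (size_t j = 0; j < block_len; j++) {
            float acc = 0.f;
            for (size_t s = 0; s < S; s++)
                acc += q[s] * k_block[j * S + s];
            out[j] = acc; // score for key position j of this block
        }
    }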
@@ -1091,7 +1092,7 @@
}

for (size_t pq = 0; pq < q_len; pq++) {
-for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
+for (size_t h = hq_beg; h < hq_end; h++) {
// apply attention mask & softmax
float* alibi_lookup = nullptr;
float alibi_slope = 0.f;
@@ -1122,7 +1123,7 @@ struct MHAHelper {
auto block_number = block_table[i];
auto* v = present_value.ptr<KVCACHE_TYPE>(block_number, hk);
for (size_t pq = 0; pq < q_len; pq++) {
-for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
+for (size_t h = hq_beg; h < hq_end; h++) {
attn_acc_value_block(_output.ptr<float>(ithr, pq, h),
_weight.ptr<float>(ithr, h, pq) + pv,
v,
@@ -1133,7 +1134,7 @@
}
// convert to dst
for (size_t pq = 0; pq < q_len; pq++)
-for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++)
+for (size_t h = hq_beg; h < hq_end; h++)
cvt_copy(output_emb.ptr<DATA_TYPE>(pq, h * _SV), _output.ptr<float>(ithr, pq, h), _SV);
}

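The weighted-value accumulation runs in fp32 (_output) regardless of DATA_TYPE; only the final cvt_copy downconverts into the output embedding. Conceptually (sketch only; the real cvt_copy is a vectorized routine):

    // Sketch only: downconvert the fp32 accumulator to the output precision.
    template <typename DST>
    void cvt_copy_ref(DST* dst, const float* src, size_t n) {
        for (size_t i = 0; i < n; i++)
            dst[i] = static_cast<DST>(src[i]); // e.g. float -> bf16/f16
    }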
@@ -1162,33 +1163,58 @@ struct MHAHelper {
// aligned to cache line (64bytes=16*sizeof(float)) to avoid false sharing
_weight_bhl.resize<float>({B, _H, q_len, rnd_up(max_context_len, std::max(_block_size, size_t{16}))});

-parallel_for3d_dynamic(B, kv_len_in_blocks, _Hk, [&](size_t b, size_t pk_in_blocks, size_t hk) {
+// for small batches the dynamic scheduler has notable overhead
+bool prefer_static_loop;
+bool loop_hk = B * kv_len_in_blocks * _Hk > 2 * _nthr;
+if (B <= 32) {
+prefer_static_loop = true;
+// small batch where every sequence has the same kv length (like the SDPA case)
+auto kv_len = past_lens.ptr<int32_t>()[0];
+for (size_t b = 1; b < B; b++) {
+if (past_lens.ptr<int32_t>()[b] != kv_len)
+prefer_static_loop = false;
+}
+} else {
+// bigger batch, most likely the vLLM path; skip the check to save cost
+prefer_static_loop = false;
+}
+auto loop_qk = [&](size_t b, size_t pk_in_blocks, size_t hx) {
auto context_len = static_cast<size_t>(past_lens.ptr<int32_t>()[b]) + 1;
+size_t hk, hq_beg, hq_end;
+if (loop_hk) {
+hk = hx;
+hq_beg = hk * _h_each_group_len;
+hq_end = (hk + 1) * _h_each_group_len;
+} else {
+hq_beg = hx;
+hq_end = hx + 1;
+hk = hx / _h_each_group_len;
+}
// kv_len must be valid
auto pk = pk_in_blocks * _block_size;
if (pk < context_len) {
auto block_number = block_indices.ptr<int32_t>()[block_indices_begins.ptr<int32_t>()[b] + pk_in_blocks];
if (one_of(_fastpath_valid_prec, ov::element::bf16, ov::element::f16)) {
_gemv->tile_config();
for (size_t pq = 0; pq < q_len; pq++) {
-for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
+for (size_t h = hq_beg; h < hq_end; h++) {
(*_gemv)(query.ptr<DATA_TYPE>(b, h, pq), present_key.ptr<KVCACHE_TYPE>(block_number, hk),
_weight_bhl.ptr<float>(b, h, pq) + pk);
}
}
_gemv->tile_release();
} else {
for (size_t pq = 0; pq < q_len; pq++) {
-for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
+for (size_t h = hq_beg; h < hq_end; h++) {
dot_product_block(query.ptr<DATA_TYPE>(b, h, pq), present_key.ptr<KVCACHE_TYPE>(block_number, hk),
_weight_bhl.ptr<float>(b, h, pq) + pk, _S, std::min(_block_size, context_len - pk));
}
}
}
}
-});
+};

-parallel_for3d_dynamic(B, _H, q_len, [&](size_t b, size_t h, size_t pq) {
+auto loop_softmax = [&](size_t b, size_t h, size_t pq) {
auto cur_kv_len = static_cast<size_t>(past_lens.ptr<int32_t>()[b]) + 1;
auto ncausal = cur_kv_len;
// apply attention mask & softmax
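The scheduling heuristic introduced in this hunk trades scheduler overhead against load balance: for B <= 32, a uniform batch (every sequence with the same past length, as in the SDPA case) is partitioned statically, while ragged or large batches keep the dynamic scheduler. The uniformity test in isolation, as a sketch (helper name hypothetical):

    // Sketch only: one mismatching past_len means the batch is ragged,
    // so the dynamic scheduler balances the uneven work better.
    bool all_same_kv_len(const int32_t* past_lens, size_t B) {
        for (size_t b = 1; b < B; b++) {
            if (past_lens[b] != past_lens[0])
                return false;
        }
        return true; // uniform batch: static partition avoids per-item dequeue cost
    }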
@@ -1210,7 +1236,14 @@ struct MHAHelper {
ov::element::f32,
ov::element::f32,
alibi_slope);
-});
+};
+if (prefer_static_loop) {
+parallel_for3d(B, kv_len_in_blocks, loop_hk ? _Hk : _H, loop_qk);
+parallel_for3d(B, _H, q_len, loop_softmax);
+} else {
+parallel_for3d_dynamic(B, kv_len_in_blocks, loop_hk ? _Hk : _H, loop_qk);
+parallel_for3d_dynamic(B, _H, q_len, loop_softmax);
+}

if (output_score) {
parallel_for2d_dynamic(B, q_len, [&](size_t b, size_t pq) {
@@ -1229,16 +1262,26 @@ struct MHAHelper {
memset(_output_bhl.ptr<float>(ithr, 0, 0, 0, 0), 0, _output_bhl.stride(0) * sizeof(float));
});

-parallel_for3d_dynamic(B, kv_len_in_blocks, _Hk, [&](size_t b, size_t pv_in_blocks, size_t hk) {
+auto loop_wk = [&](size_t b, size_t pv_in_blocks, size_t hx) {
auto ithr = parallel_get_thread_num();
auto context_len = static_cast<size_t>(past_lens.ptr<int32_t>()[b]) + 1;
auto pv = pv_in_blocks * _block_size;
+size_t hk, hq_beg, hq_end;
+if (loop_hk) {
+hk = hx;
+hq_beg = hk * _h_each_group_len;
+hq_end = (hk + 1) * _h_each_group_len;
+} else {
+hq_beg = hx;
+hq_end = hx + 1;
+hk = hx / _h_each_group_len;
+}
// kv_len must be valid
if (pv < context_len) {
auto block_number = block_indices.ptr<int32_t>()[block_indices_begins.ptr<int32_t>()[b] + pv_in_blocks];
auto* v = present_value.ptr<KVCACHE_TYPE>(block_number, hk);
for (size_t pq = 0; pq < q_len; pq++) {
-for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
+for (size_t h = hq_beg; h < hq_end; h++) {
attn_acc_value_block(_output_bhl.ptr<float>(ithr, b, pq, h),
_weight_bhl.ptr<float>(b, h, pq) + pv,
v,
@@ -1247,7 +1290,13 @@
}
}
}
-});
+};
+
+if (prefer_static_loop) {
+parallel_for3d(B, kv_len_in_blocks, loop_hk ? _Hk : _H, loop_wk);
+} else {
+parallel_for3d_dynamic(B, kv_len_in_blocks, loop_hk ? _Hk : _H, loop_wk);
+}

parallel_for3d(B, _H, q_len, [&](size_t b, size_t h, size_t pq) {
auto* temp = _output_bhl.ptr<float>(0, b, pq, h);
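Each thread accumulates into its own slice of _output_bhl (zeroed per thread above), so the loop starting here has to reduce the per-thread partials into the final output; the tail of that loop is collapsed in this view. A sketch of the reduce step, assuming thread partials are laid out along the leading dimension with stride _output_bhl.stride(0):

    // Sketch only: sum thread-local partials for one (b, pq, h) row of size SV.
    void reduce_thread_partials(const float* partials, size_t nthr,
                                size_t thread_stride, float* dst, size_t SV) {
        for (size_t s = 0; s < SV; s++)
            dst[s] = partials[s];                      // thread 0
        for (size_t t = 1; t < nthr; t++)
            for (size_t s = 0; s < SV; s++)
                dst[s] += partials[t * thread_stride + s];
    }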
@@ -1416,7 +1465,23 @@ struct MHA {
}
});

-parallel_for2d_dynamic(attn_work_count, Hk, [&](size_t w, size_t hk) {
+// loop along the HK dimension: if first- and second-token sequences are mixed and there are enough work items, loop over HK to reuse KV in the CPU cache;
+// otherwise, with few work items, loop over H to create more work items and avoid thread imbalance
+bool loop_hk = _workitems.get_reorder_max_batch_size() == past_lens.m_dims[0] || // if only first token, loop H
+attn_work_count * Hk <= 2 * _helper._nthr ? false : true; // or fewer than 2 work items per thread, loop H
+
+parallel_for2d_dynamic(attn_work_count, loop_hk ? Hk : _helper._H, [&](size_t w, size_t hx) {
+size_t hk, hq_beg, hq_end;
+if (loop_hk) {
+hk = hx;
+hq_beg = hk * _helper._h_each_group_len;
+hq_end = (hk + 1) * _helper._h_each_group_len;
+} else {
+hq_beg = hx;
+hq_end = hx + 1;
+hk = hx / _helper._h_each_group_len;
+}
+
const auto& item = _workitems.get_attn_work_item(w);
const auto batch_in_seq = item.batch_in_seq;
const auto batch_in_token = subsequence_begins.ptr<int32_t>()[batch_in_seq];
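The `cond ? false : true` form of the loop_hk condition in this hunk reads inverted; spelled out, it is equivalent to the following sketch:

    // loop_hk is true only when first- and second-token work is mixed AND
    // there are enough work items to keep every thread busy.
    bool mixed_tokens = _workitems.get_reorder_max_batch_size() != past_lens.m_dims[0];
    bool enough_items = attn_work_count * Hk > 2 * _helper._nthr;
    bool loop_hk = mixed_tokens && enough_items;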
@@ -1434,7 +1499,7 @@ struct MHA {
_helper.exec_kernel_one_bh(q.slice(0, batch_in_token, batch_in_token), k_cache, v_cache,
output_emb.slice(0, batch_in_token, batch_in_token),
block_indices.ptr<int32_t>() + block_indices_begins.ptr<int32_t>()[batch_in_seq],
-ithr, hk, 1ul, cur_kv_len, alibi_slopes,
+ithr, hq_beg, hq_end, hk, 1ul, cur_kv_len, alibi_slopes,
score_output);
} else {
const auto batch_in_reorder = item.batch_in_reorder;
@@ -1461,6 +1526,8 @@
block_indices.ptr<int32_t>() + block_indices_begins.ptr<int32_t>()[batch_in_seq],
ithr,
q_blk,
+hq_beg,
+hq_end,
hk,
q_len,
cur_kv_len,