[CPU] Optimize small batch case for PagedAttention (#27847)
### Details:
 - *Generate more work items to avoid thread imbalance (see the sketch below, before the diff)*
 - *...*

### Tickets:
 - *[156347](https://jira.devtools.intel.com/browse/CVS-156347)*
 - *[158477](https://jira.devtools.intel.com/browse/CVS-158477)*
luo-cheng2021 authored Dec 6, 2024
1 parent ea72f30 commit d62effb
Showing 1 changed file with 86 additions and 19 deletions.
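
The heart of the change is how attention work is split across threads: when the product of batch size, KV-block count, and KV-head count yields too few work items to keep every thread busy, the kernels iterate over individual query heads (H) instead of whole KV-head groups (Hk), and the inner loops receive an explicit `hq_beg`/`hq_end` range instead of deriving it from `hk`. Below is a minimal standalone sketch of that partitioning, using hypothetical helper names (`map_head`, `choose_loop_hk`) rather than the committed code:

```cpp
#include <cstddef>

struct HeadRange {
    size_t hq_beg, hq_end, hk;  // query-head range [hq_beg, hq_end) and its KV head
};

// Map one parallel-loop index 'hx' to a head range. When loop_hk is true, a work item
// covers a whole KV-head group (good KV-cache reuse); otherwise each query head becomes
// its own work item (more items, better thread balance).
inline HeadRange map_head(bool loop_hk, size_t hx, size_t h_each_group_len) {
    if (loop_hk) {
        return {hx * h_each_group_len, (hx + 1) * h_each_group_len, hx};
    }
    return {hx, hx + 1, hx / h_each_group_len};
}

// Heuristic mirrored from the patch: keep the coarse Hk loop only if it still yields
// roughly 2+ work items per thread; otherwise fall back to the finer-grained H loop.
inline bool choose_loop_hk(size_t batch, size_t kv_blocks, size_t kv_heads, size_t nthr) {
    return batch * kv_blocks * kv_heads > 2 * nthr;
}
```

With such a mapping, `exec_kernel_one_bh` and `exec_kernel_multiple` can stay agnostic of which dimension the outer parallel loop runs over.
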
src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp
@@ -939,14 +939,14 @@ struct MHAHelper {
// wv_scratch_b: [rnd_up(kv_len, block_size), Hk, scratch_b_size]
void exec_kernel_multiple(const PlainTensor& query, const PlainTensor& present_value, const PlainTensor& output_emb,
const PlainTensor& qk_scratch_b, const PlainTensor& wv_scratch_b, const int32_t* block_table, size_t ithr, size_t q_blk,
size_t hk, size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes, float* score_output) {
size_t hq_beg, size_t hq_end, size_t hk, size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes, float* score_output) {
auto q_start = q_blk * _block_size;
auto q_end = std::min(q_start + _block_size, q_len);
auto q_cnt = q_end - q_start;
constexpr bool q_is_xf16 = one_of(precision_of<DATA_TYPE>::value, ov::element::bf16, ov::element::f16);
constexpr bool q_cache_is_same = precision_of<DATA_TYPE>::value == precision_of<KVCACHE_TYPE>::value;
auto cur_kv_len_blocks = div_up(cur_kv_len, _block_size);
for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
for (size_t h = hq_beg; h < hq_end; h++) {
auto* q_ptr = query.ptr<DATA_TYPE>(h, q_start, 0);
float* c_ptr = _weight.ptr<float>(ithr, h, 0, 0);
// for each query block, loop through all key block
@@ -1065,13 +1065,14 @@ struct MHAHelper {
// weight: [nthr, H, 32, rnd_up(kv_len, block_size)]
// output: [nthr, 32, H, S]
void exec_kernel_one_bh(const PlainTensor& query, const PlainTensor& present_key, const PlainTensor& present_value, const PlainTensor& output_emb,
const int32_t* block_table, size_t ithr, size_t hk, size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes, float* score_output) {
const int32_t* block_table, size_t ithr, size_t hq_beg, size_t hq_end, size_t hk,
size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes, float* score_output) {
if (one_of(_fastpath_valid_prec, ov::element::bf16, ov::element::f16)) {
_gemv->tile_config();
for (size_t pk = 0, i = 0; pk < cur_kv_len; pk += _block_size, i++) {
auto block_number = block_table[i];
for (size_t pq = 0; pq < q_len; pq++) {
for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
for (size_t h = hq_beg; h < hq_end; h++) {
(*_gemv)(query.ptr<DATA_TYPE>(h, pq), present_key.ptr<KVCACHE_TYPE>(block_number, hk),
_weight.ptr<float>(ithr, h, pq) + pk);
}
@@ -1082,7 +1083,7 @@ struct MHAHelper {
for (size_t pk = 0, i = 0; pk < cur_kv_len; pk += _block_size, i++) {
auto block_number = block_table[i];
for (size_t pq = 0; pq < q_len; pq++) {
for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
for (size_t h = hq_beg; h < hq_end; h++) {
dot_product_block(query.ptr<DATA_TYPE>(h, pq), present_key.ptr<KVCACHE_TYPE>(block_number, hk),
_weight.ptr<float>(ithr, h, pq) + pk, _S, std::min(_block_size, cur_kv_len - pk));
}
@@ -1091,7 +1092,7 @@ struct MHAHelper {
}

for (size_t pq = 0; pq < q_len; pq++) {
for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
for (size_t h = hq_beg; h < hq_end; h++) {
// apply attention mask & softmax
float* alibi_lookup = nullptr;
float alibi_slope = 0.f;
@@ -1122,7 +1123,7 @@ struct MHAHelper {
auto block_number = block_table[i];
auto* v = present_value.ptr<KVCACHE_TYPE>(block_number, hk);
for (size_t pq = 0; pq < q_len; pq++) {
for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
for (size_t h = hq_beg; h < hq_end; h++) {
attn_acc_value_block(_output.ptr<float>(ithr, pq, h),
_weight.ptr<float>(ithr, h, pq) + pv,
v,
@@ -1133,7 +1134,7 @@ struct MHAHelper {
}
// convert to dst
for (size_t pq = 0; pq < q_len; pq++)
for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++)
for (size_t h = hq_beg; h < hq_end; h++)
cvt_copy(output_emb.ptr<DATA_TYPE>(pq, h * _SV), _output.ptr<float>(ithr, pq, h), _SV);
}

@@ -1162,33 +1163,63 @@ struct MHAHelper {
// aligned to cache line (64bytes=16*sizeof(float)) to avoid false sharing
_weight_bhl.resize<float>({B, _H, q_len, rnd_up(max_context_len, std::max(_block_size, size_t{16}))});

parallel_for3d_dynamic(B, kv_len_in_blocks, _Hk, [&](size_t b, size_t pk_in_blocks, size_t hk) {
// for small batches the dynamic scheduler has notable overhead
bool prefer_static_loop;
// if there are fewer than 2 work items per thread, loop over H
bool loop_hk = B * kv_len_in_blocks * _Hk <= 2 * _nthr ? false : true;
if (B <= 32) {
prefer_static_loop = true;
// small batch and all past kv lengths are the same (like the SDPA case)
auto kv_len = past_lens.ptr<int32_t>()[0];
for (size_t b = 1; b < B; b++) {
if (past_lens.ptr<int32_t>()[b] != kv_len)
prefer_static_loop = false;
}
} else {
// for bigger batches, skip the check to save its cost
prefer_static_loop = false;
}
auto get_h_params = [] (bool loop_hk, size_t hx, size_t h_each_group_len, size_t& hq_beg, size_t& hq_end, size_t& hk) {
if (loop_hk) {
hk = hx;
hq_beg = hk * h_each_group_len;
hq_end = (hk + 1) * h_each_group_len;
} else {
hq_beg = hx;
hq_end = hx + 1;
hk = hx / h_each_group_len;
}
};
auto loop_qk = [&](size_t b, size_t pk_in_blocks, size_t hx) {
auto context_len = static_cast<size_t>(past_lens.ptr<int32_t>()[b]) + 1;
size_t hk, hq_beg, hq_end;
get_h_params(loop_hk, hx, _h_each_group_len, hq_beg, hq_end, hk);

// kv_len must be valid
auto pk = pk_in_blocks * _block_size;
if (pk < context_len) {
auto block_number = block_indices.ptr<int32_t>()[block_indices_begins.ptr<int32_t>()[b] + pk_in_blocks];
if (one_of(_fastpath_valid_prec, ov::element::bf16, ov::element::f16)) {
_gemv->tile_config();
for (size_t pq = 0; pq < q_len; pq++) {
for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
for (size_t h = hq_beg; h < hq_end; h++) {
(*_gemv)(query.ptr<DATA_TYPE>(b, h, pq), present_key.ptr<KVCACHE_TYPE>(block_number, hk),
_weight_bhl.ptr<float>(b, h, pq) + pk);
}
}
_gemv->tile_release();
} else {
for (size_t pq = 0; pq < q_len; pq++) {
for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
for (size_t h = hq_beg; h < hq_end; h++) {
dot_product_block(query.ptr<DATA_TYPE>(b, h, pq), present_key.ptr<KVCACHE_TYPE>(block_number, hk),
_weight_bhl.ptr<float>(b, h, pq) + pk, _S, std::min(_block_size, context_len - pk));
}
}
}
}
});
};

parallel_for3d_dynamic(B, _H, q_len, [&](size_t b, size_t h, size_t pq) {
auto loop_softmax = [&](size_t b, size_t h, size_t pq) {
auto cur_kv_len = static_cast<size_t>(past_lens.ptr<int32_t>()[b]) + 1;
auto ncausal = cur_kv_len;
// apply attention mask & softmax
Expand All @@ -1210,7 +1241,16 @@ struct MHAHelper {
ov::element::f32,
ov::element::f32,
alibi_slope);
});
};

size_t h_dims = loop_hk ? _Hk : _H;
if (prefer_static_loop) {
parallel_for3d(B, kv_len_in_blocks, h_dims, loop_qk);
parallel_for3d(B, _H, q_len, loop_softmax);
} else {
parallel_for3d_dynamic(B, kv_len_in_blocks, h_dims, loop_qk);
parallel_for3d_dynamic(B, _H, q_len, loop_softmax);
}

if (output_score) {
parallel_for2d_dynamic(B, q_len, [&](size_t b, size_t pq) {
@@ -1229,16 +1269,19 @@ struct MHAHelper {
memset(_output_bhl.ptr<float>(ithr, 0, 0, 0, 0), 0, _output_bhl.stride(0) * sizeof(float));
});

parallel_for3d_dynamic(B, kv_len_in_blocks, _Hk, [&](size_t b, size_t pv_in_blocks, size_t hk) {
auto loop_wk = [&](size_t b, size_t pv_in_blocks, size_t hx) {
auto ithr = parallel_get_thread_num();
auto context_len = static_cast<size_t>(past_lens.ptr<int32_t>()[b]) + 1;
auto pv = pv_in_blocks * _block_size;
size_t hk, hq_beg, hq_end;
get_h_params(loop_hk, hx, _h_each_group_len, hq_beg, hq_end, hk);

// kv_len must be valid
if (pv < context_len) {
auto block_number = block_indices.ptr<int32_t>()[block_indices_begins.ptr<int32_t>()[b] + pv_in_blocks];
auto* v = present_value.ptr<KVCACHE_TYPE>(block_number, hk);
for (size_t pq = 0; pq < q_len; pq++) {
for (size_t h = hk * _h_each_group_len; h < (hk + 1) * _h_each_group_len; h++) {
for (size_t h = hq_beg; h < hq_end; h++) {
attn_acc_value_block(_output_bhl.ptr<float>(ithr, b, pq, h),
_weight_bhl.ptr<float>(b, h, pq) + pv,
v,
@@ -1247,7 +1290,13 @@ struct MHAHelper {
}
}
}
});
};

if (prefer_static_loop) {
parallel_for3d(B, kv_len_in_blocks, loop_hk ? _Hk : _H, loop_wk);
} else {
parallel_for3d_dynamic(B, kv_len_in_blocks, loop_hk ? _Hk : _H, loop_wk);
}

parallel_for3d(B, _H, q_len, [&](size_t b, size_t h, size_t pq) {
auto* temp = _output_bhl.ptr<float>(0, b, pq, h);
@@ -1416,7 +1465,23 @@ struct MHA {
}
});

parallel_for2d_dynamic(attn_work_count, Hk, [&](size_t w, size_t hk) {
// loop along the Hk dimension: if first- and second-token requests are mixed and there are enough work items, loop over Hk to reuse KV in the CPU cache
// otherwise, if the work item count is small, prefer to loop over H to create more work items and avoid thread imbalance
bool loop_hk = _workitems.get_reorder_max_batch_size() == past_lens.m_dims[0] || // if only first token, loop H
attn_work_count * Hk <= 2 * _helper._nthr ? false : true; // or less than 2 work items per thread, loop H

parallel_for2d_dynamic(attn_work_count, loop_hk ? Hk : _helper._H, [&](size_t w, size_t hx) {
size_t hk, hq_beg, hq_end;
if (loop_hk) {
hk = hx;
hq_beg = hk * _helper._h_each_group_len;
hq_end = (hk + 1) * _helper._h_each_group_len;
} else {
hq_beg = hx;
hq_end = hx + 1;
hk = hx / _helper._h_each_group_len;
}

const auto& item = _workitems.get_attn_work_item(w);
const auto batch_in_seq = item.batch_in_seq;
const auto batch_in_token = subsequence_begins.ptr<int32_t>()[batch_in_seq];
@@ -1434,7 +1499,7 @@ struct MHA {
_helper.exec_kernel_one_bh(q.slice(0, batch_in_token, batch_in_token), k_cache, v_cache,
output_emb.slice(0, batch_in_token, batch_in_token),
block_indices.ptr<int32_t>() + block_indices_begins.ptr<int32_t>()[batch_in_seq],
ithr, hk, 1ul, cur_kv_len, alibi_slopes,
ithr, hq_beg, hq_end, hk, 1ul, cur_kv_len, alibi_slopes,
score_output);
} else {
const auto batch_in_reorder = item.batch_in_reorder;
@@ -1461,6 +1526,8 @@ struct MHA {
block_indices.ptr<int32_t>() + block_indices_begins.ptr<int32_t>()[batch_in_seq],
ithr,
q_blk,
hq_beg,
hq_end,
hk,
q_len,
cur_kv_len,

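To see why the finer split matters, the worked example below uses hypothetical shapes (batch 1, 4 KV blocks, 8 KV heads, 32 query heads, 56 threads); none of these numbers come from the patch. It also sketches the companion scheduling decision: a static `parallel_for` is preferred when the batch is small and every sequence shares the same past length, since each work item then costs roughly the same and the dynamic scheduler's overhead is pure loss.

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    // Illustrative shapes (not from the patch): batch 1, 4 KV blocks,
    // 8 KV heads, 32 query heads (4:1 GQA), 56 worker threads.
    const size_t B = 1, kv_blocks = 4, Hk = 8, H = 32, nthr = 56;

    const size_t items_hk = B * kv_blocks * Hk;  // 32 items  -> fewer than 56 threads, some stay idle
    const size_t items_h  = B * kv_blocks * H;   // 128 items -> every thread gets work
    const bool loop_hk = items_hk > 2 * nthr;    // false here, so loop over H instead of Hk
    std::cout << "items over Hk: " << items_hk << ", items over H: " << items_h
              << ", loop_hk: " << loop_hk << "\n";

    // Companion choice: static scheduling is cheaper when every work item costs about
    // the same, i.e. a small batch whose sequences all share one past length; otherwise
    // the dynamic scheduler absorbs the uneven per-sequence work.
    const std::vector<int32_t> past_lens = {127, 127, 127};  // hypothetical uniform batch
    bool prefer_static = past_lens.size() <= 32;
    for (size_t b = 1; prefer_static && b < past_lens.size(); b++) {
        prefer_static = (past_lens[b] == past_lens[0]);
    }
    std::cout << "prefer_static: " << prefer_static << "\n";
    return 0;
}
```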