Skip to content

Commit

Permalink
[GPU] Improve sdpa_opt kernel by skipping computation of causal mask (ope…
Browse files Browse the repository at this point in the history
…nvinotoolkit#28260)

### Details:
 - *Skip computation for positions excluded by the causal mask*
 - *...*

### Tickets:
 - *151857*

---------

Co-authored-by: Chen Peter <[email protected]>
  • Loading branch information
ceciliapeng2011 and peterchen-intel authored Jan 9, 2025
1 parent acae7cc commit 1694ea8
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 5 deletions.
34 changes: 29 additions & 5 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,11 @@ KERNEL(sdpa_opt)(
#define b0_idx (batch_idx / NUM_HEADS)
#define b1_idx (batch_idx % NUM_HEADS)
#define target_seq_dim ((uint)get_global_id(1))
#if IS_PAGED_ATTENTION
#define target_seq_idx ((uint)block_start_pos - subsequence_begins[gws_seq_indexes_correspondence[target_seq_dim]])
#else
#define target_seq_idx ((uint)get_global_id(1) * TARGET_SEQ_LEN_BLOCK_SIZE)
#endif
#define head_size_idx ((uint)get_local_id(2) % HEAD_SIZE)
#define sglid (uint)get_sub_group_local_id()
#define sgid (uint)get_sub_group_id()
Expand Down Expand Up @@ -994,8 +998,15 @@ KERNEL(sdpa_opt)(
__attribute__((opencl_unroll_hint(1)))
for (uint start_partition_idx = 0; start_partition_idx < SOURCE_SEQ_LEN; start_partition_idx += SEQ_LEN_PARTITION_SIZE) {
const uint seq_len = start_partition_idx + sgid * SUBGROUP_SIZE;
#if IS_CAUSAL
const uint partition_seq_len = min((uint)SEQ_LEN_PARTITION_SIZE, (uint)max(0, (int)(target_seq_idx + seq_idx_end) - (int)start_partition_idx));
#else
const uint partition_seq_len = min((uint)SOURCE_SEQ_LEN - start_partition_idx, (uint)SEQ_LEN_PARTITION_SIZE);
#endif

#if IS_CAUSAL
if (seq_len <= target_seq_idx) { // keep tril i.e. m >= n
#endif
#if IS_PAGED_ATTENTION
#ifdef BROADCAST_GROUP_SIZE
const uint heads_dim = num_heads_dim / BROADCAST_GROUP_SIZE;
Expand Down Expand Up @@ -1026,21 +1037,21 @@ KERNEL(sdpa_opt)(
#endif

int seq_len_calc_size = min((int)(SOURCE_SEQ_LEN) - (int)seq_len, (int)SUBGROUP_SIZE);
#if IS_CAUSAL
MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZERO;
#else // !IS_CAUSAL
MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc;

qk_acc = FUNC_CALL(load_attn_mask)(OPTIONAL_SHAPE_INFO_TENSOR
b0_idx,
b1_idx,
#if IS_PAGED_ATTENTION
block_start_pos - subsequence_begins[gws_seq_indexes_correspondence[target_seq_dim]] + sglid,
#else
target_seq_idx + sglid,
#endif
// TODO: pass seq_len_calc_size here
seq_len
ATTN_MASK_BUFFER
ATTN_SCALE_BUFFER
PA_BUFFERS);
#endif // !IS_CAUSAL

if (seq_len_calc_size >= SUBGROUP_SIZE) {
#if IS_KV_COMPRESSED
Expand Down Expand Up @@ -1157,6 +1168,10 @@ KERNEL(sdpa_opt)(
{
SOFTMAX_ACCUMULATOR_TYPE qk_max = SOFTMAX_ACCUMULATOR_VAL_MIN;
unroll_for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) {
#if IS_CAUSAL
// causal mask: valid only if m >= n
if (seq_len + i <= target_seq_idx + sglid) {
#endif // IS_CAUSAL
#if !APPLY_SCALES_TO_QUERY
#if HAS_SCALE_INPUT
const OUTPUT_TYPE scale_val = *scale;
Expand All @@ -1172,12 +1187,21 @@ KERNEL(sdpa_opt)(
#endif

qk_acc[i] = INPUT0_MIN_FUNC(INPUT0_MAX_FUNC(qk_acc[i], INPUT0_VAL_MIN), INPUT0_VAL_MAX);

#if IS_CAUSAL
} else {
qk_acc[i] = INPUT0_VAL_MIN;
}
#endif // IS_CAUSAL
qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max, TO_SOFTMAX_ACCUMULATOR_TYPE(qk_acc[i]));
slm_qk_vals[sglid][sgid * TARGET_SEQ_LEN_BLOCK_SIZE + i] = qk_acc[i];
}
slm_qk_max_vals[sglid][sgid] = qk_max;
}
#if IS_CAUSAL
} else { // skip triu
slm_qk_max_vals[sglid][sgid] = SOFTMAX_ACCUMULATOR_VAL_MIN;
}
#endif

barrier(CLK_LOCAL_MEM_FENCE);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,22 @@ std::vector<Params> get_test_params() {
p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 1, 1, {0, 1, 2, 3}});
p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
p.push_back({with_rearrange, with_mask, !with_scale, !causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});

/* -- causal mask -- */

p.push_back({with_rearrange, !with_mask, !with_scale, causal, !compressed, 1, ov::element::Type_t::f16, 10, 1, 1, {0, 1, 2, 3}});
p.push_back({with_rearrange, with_mask, !with_scale, causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
p.push_back({with_rearrange, with_mask, !with_scale, causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
p.push_back({!with_rearrange, with_mask, !with_scale, causal, !compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});

// Beam search
p.push_back({with_rearrange, !with_mask, !with_scale, causal, !compressed, 2, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
p.push_back({with_rearrange, !with_mask, !with_scale, causal, !compressed, 4, ov::element::Type_t::f16, 5, 16, 1, {0, 2, 1, 3}});

// Compressed
p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 1, 1, {0, 1, 2, 3}});
p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 2, 1, 3}});
p.push_back({with_rearrange, with_mask, !with_scale, causal, compressed, 1, ov::element::Type_t::f16, 10, 4, 1, {0, 1, 2, 3}});
return p;
}

Expand Down

0 comments on commit 1694ea8

Please sign in to comment.