[GPU] PagedAttention improvements and fixes (openvinotoolkit#29248)
### Details:
- Extended IncreasePositionIdsPrecision to support MatMul followed by Reshape, improving accuracy for large context sizes
- Added support for the sliding window size PA parameter (see the mask sketch after this list)
- Fixed PA caching: added the missing scale serialization and fixed the PA primitive descriptor comparator
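
The sliding-window support below comes down to one extra term in the kernels' causal masks (the pa_sdpa_opt and load_attn_mask hunks): a key is visible to a query only if it is not in the future and, when a window is configured, not older than the window. A minimal host-side sketch of that visibility rule follows; the names `is_visible` and `window` are illustrative and not part of the commit, and the form mirrors the prefill-path condition:

```cpp
#include <cstddef>
#include <iostream>

// Sketch of the sliding-window causal visibility rule; window == 0 means "no window".
bool is_visible(size_t query_pos, size_t key_pos, size_t window) {
    if (key_pos > query_pos)
        return false;  // causal part: future keys are always masked
    if (window != 0 && query_pos >= window && key_pos < query_pos - window)
        return false;  // sliding-window part: keys older than the window are masked
    return true;
}

int main() {
    // With window = 4, the query at position 10 attends to keys 6..10 under this predicate.
    for (size_t k = 0; k <= 12; ++k)
        std::cout << k << (is_visible(10, k, 4) ? ": visible\n" : ": masked\n");
}
```
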
sshlyapn authored Mar 4, 2025
1 parent da0a6d3 commit ea05578
Showing 10 changed files with 108 additions and 2 deletions.

@@ -30,7 +30,18 @@ struct paged_attention : public primitive_base<paged_attention> {
}

bool operator==(const primitive& rhs) const override {
return compare_common_params(rhs);
if (!compare_common_params(rhs))
return false;

auto rhs_casted = downcast<const paged_attention>(rhs);

return head_size == rhs_casted.head_size &&
heads_num == rhs_casted.heads_num &&
kv_heads_num == rhs_casted.kv_heads_num &&
sliding_window == rhs_casted.sliding_window &&
has_alibi == rhs_casted.has_alibi &&
has_rotated_blocks == rhs_casted.has_rotated_blocks &&
scale_val.value_or(1.0f) == rhs_casted.scale_val.value_or(1.0f);
}

void save(BinaryOutputBuffer& ob) const override {
@@ -40,6 +51,14 @@ struct paged_attention : public primitive_base<paged_attention> {
ob << kv_heads_num;
ob << has_alibi;
ob << has_rotated_blocks;
ob << sliding_window;

if (scale_val.has_value()) {
ob << true;
ob << scale_val.value();
} else {
ob << false;
}
}

void load(BinaryInputBuffer& ib) override {
@@ -49,12 +68,24 @@ struct paged_attention : public primitive_base<paged_attention> {
ib >> kv_heads_num;
ib >> has_alibi;
ib >> has_rotated_blocks;
ib >> sliding_window;

bool has_scale;
ib >> has_scale;
if (has_scale) {
float scale = 1.0f;
ib >> scale;
scale_val = scale;
} else {
scale_val = std::optional<float>();
}
}

std::optional<float> scale_val{};
size_t head_size = 0;
size_t heads_num = 0;
size_t kv_heads_num = 0;
size_t sliding_window = 0;
bool has_alibi = false;
bool has_rotated_blocks = false;
};
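
The save/load additions above use the usual presence-flag idiom for the optional scale: write a bool first, then the value only when it is set, and mirror that order on load. A self-contained sketch of the same idiom — std::iostream stands in for cldnn's BinaryOutputBuffer/BinaryInputBuffer, so the function names and stream types here are illustrative only:

```cpp
#include <iostream>
#include <optional>
#include <sstream>

// Serialize an optional value as a presence flag followed (optionally) by the payload.
void save_scale(std::ostream& os, const std::optional<float>& scale_val) {
    const bool has_scale = scale_val.has_value();
    os.write(reinterpret_cast<const char*>(&has_scale), sizeof(has_scale));
    if (has_scale) {
        const float v = scale_val.value();
        os.write(reinterpret_cast<const char*>(&v), sizeof(v));
    }
}

// Read the flag first, then the payload only if the flag says it was written.
std::optional<float> load_scale(std::istream& is) {
    bool has_scale = false;
    is.read(reinterpret_cast<char*>(&has_scale), sizeof(has_scale));
    if (!has_scale)
        return std::nullopt;
    float v = 1.0f;
    is.read(reinterpret_cast<char*>(&v), sizeof(v));
    return v;
}

int main() {
    std::stringstream buf;
    save_scale(buf, 0.125f);                              // present value round-trips
    std::cout << load_scale(buf).value_or(1.0f) << "\n";  // prints 0.125
}
```

Note that operator== compares scale_val through value_or(1.0f), so an unset scale and an explicit 1.0f are treated as equivalent when looking up cached primitives.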

@@ -665,6 +665,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
config.is_causal = true;
config.is_paged_attention = true;
config.paged_attention_block_size = static_cast<int64_t>(paged_attention::block_size);
config.paged_attention_sliding_window = desc->sliding_window;

if (desc->scale_val.has_value()) {
config.has_const_scale_val = true;

@@ -197,7 +197,11 @@ KERNEL(pa_sdpa_opt)(
qk_acc += alibi_slopes[head_num_idx] * alibi_val;
#endif

#if SLIDING_WINDOW_SIZE != 0
if (token_idx >= seq_len || (seq_len > SLIDING_WINDOW_SIZE && token_idx < (seq_len - SLIDING_WINDOW_SIZE)))
#else
if (token_idx >= seq_len)
#endif
qk_acc = SOFTMAX_ACCUMULATOR_VAL_MIN;

qk_max = SOFTMAX_ACCUMULATOR_MAX_FUNC(qk_max, TO_SOFTMAX_ACCUMULATOR_TYPE(qk_acc));

@@ -784,7 +784,12 @@ inline MASK_VECTOR_TYPE FUNC(load_attn_mask)(OPTIONAL_SHAPE_INFO_ARG
}
} else {
for (uint i = 0; i < SUBGROUP_SIZE; i++) {
#if defined(IS_PAGED_ATTENTION) && SLIDING_WINDOW_SIZE != 0
if ((source_seq_idx + i > target_seq_idx) ||
(target_seq_idx >= SLIDING_WINDOW_SIZE && source_seq_idx + i < target_seq_idx - SLIDING_WINDOW_SIZE))
#else
if (source_seq_idx + i > target_seq_idx)
#endif
mask_vec[i] = NAN;
}
}

@@ -1167,7 +1172,11 @@ MAKE_VECTOR_TYPE(INPUT0_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) qk_acc = INPUT0_VAL_ZER
unroll_for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) {
#if IS_CAUSAL
// causal mask: valid only if m >= n
#if defined(IS_PAGED_ATTENTION) && SLIDING_WINDOW_SIZE != 0
if ((seq_len + i <= target_seq_idx + sglid) && (target_seq_idx + sglid < SLIDING_WINDOW_SIZE || seq_len + i >= target_seq_idx + sglid - SLIDING_WINDOW_SIZE)) {
#else
if (seq_len + i <= target_seq_idx + sglid) {
#endif
#endif // IS_CAUSAL
#if !APPLY_SCALES_TO_QUERY
#if HAS_SCALE_INPUT

@@ -206,6 +206,7 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
jit.AddConstant(MakeJitConstant("SEQ_LEN_PARTITION_SIZE", seq_len_partition_size));
jit.AddConstant(MakeJitConstant("PAGED_ATTENTION_BLOCK_SIZE", paged_attention_block_size));
jit.AddConstant(MakeJitConstant("SUBGROUP_SIZE", subgroup_size));
jit.AddConstant(MakeJitConstant("SLIDING_WINDOW_SIZE", config.paged_attention_sliding_window));

if (config.broadcast_axis != -1) {
jit.AddConstant(MakeJitConstant("BROADCAST_GROUP_SIZE", config.group_size));

@@ -95,6 +95,7 @@ struct sdpa_configuration {

// Paged Attention configuration
bool is_paged_attention = false;
size_t paged_attention_sliding_window = 0;
int64_t paged_attention_aligned_seq_len = -1;
int64_t paged_attention_block_size = 0;
int64_t paged_attention_max_len = 0;

@@ -186,6 +186,7 @@ JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params, size_t ke
jit.AddConstant(MakeJitConstant("SG_SCALE_FACTOR", get_sg_number_scale_factor(params, config.head_size, kernel_idx)));

if (params.conf.is_paged_attention) {
jit.AddConstant(MakeJitConstant("SLIDING_WINDOW_SIZE", params.conf.paged_attention_sliding_window));
if (params.conf.has_alibi_input) {
jit.AddConstant(MakeJitConstant("HAS_ALIBI", 1));
}

7 changes: 7 additions & 0 deletions src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp
@@ -52,6 +52,7 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared
prim.heads_num = heads_num;

const size_t scale_idx = 9;
const size_t sliding_window_idx = 10;
const size_t alibi_idx = 11;

std::shared_ptr<ov::op::v0::Constant> scale_const = ov::as_type_ptr<ov::op::v0::Constant>(op->get_input_node_shared_ptr(scale_idx));
@@ -62,6 +63,12 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared
prim.scale_val = std::optional<float>();
}

std::shared_ptr<ov::op::v0::Constant> sliding_windows_const = ov::as_type_ptr<ov::op::v0::Constant>(op->get_input_node_shared_ptr(sliding_window_idx));
if (sliding_windows_const) {
OPENVINO_ASSERT(ov::shape_size(sliding_windows_const->get_output_shape(0)) == 1);
prim.sliding_window = sliding_windows_const->cast_vector<size_t>()[0];
}

std::shared_ptr<ov::op::v0::Constant> alibi_const = ov::as_type_ptr<ov::op::v0::Constant>(op->get_input_node_shared_ptr(alibi_idx));
OPENVINO_ASSERT(alibi_const != nullptr);
prim.has_alibi = ov::shape_size(alibi_const->get_output_shape(0)) > 0;

@@ -31,7 +31,8 @@ IncreasePositionIdsPrecision::IncreasePositionIdsPrecision() {

auto gemm_or_matmul = wrap_type<ov::intel_gpu::op::Gemm, ov::op::v0::MatMul>();
auto transpose_m = wrap_type<ov::op::v1::Transpose>({gemm_or_matmul, any_input()});
auto concat_input = std::make_shared<Or>(OutputVector{gemm_or_matmul, transpose_m});
auto reshape_m = wrap_type<ov::op::v1::Reshape>({gemm_or_matmul, any_input()});
auto concat_input = std::make_shared<Or>(OutputVector{gemm_or_matmul, transpose_m, reshape_m});
auto concat = wrap_type<ov::op::v0::Concat>({concat_input, concat_input});
auto sin = wrap_type<ov::op::v0::Sin>({concat});
auto cos = wrap_type<ov::op::v0::Cos>({concat});
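
With the added Reshape branch, the pattern now also covers rotary-embedding subgraphs where the position-ids MatMul output is reshaped before the Concat feeding Sin/Cos. The new IncreasePositionIdsReshapeAfterMatmul test below builds exactly that subgraph and expects the MatMul inputs to be upcast to f32, with the Sin/Cos results converted back to f16.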

@@ -176,3 +176,53 @@ TEST_F(TransformationTestsF, IncreasePositionIdsMatmulWithoutUnsqueeze) {
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}

TEST_F(TransformationTestsF, IncreasePositionIdsReshapeAfterMatmul) {
{
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ -1, -1 });
auto input_convert = std::make_shared<ov::op::v0::Convert>(input, ov::element::i32);
auto input_unsqueeze = std::make_shared<ov::op::v0::Unsqueeze>(input_convert, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1}));
auto input_convert_fp = std::make_shared<ov::op::v0::Convert>(input_unsqueeze, ov::element::f16);
auto rotary_embd_const = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{1, 64, 1});
auto reshape_dims = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{3});

auto matmul = std::make_shared<ov::op::v0::MatMul>(input_convert_fp, rotary_embd_const);
auto reshape = std::make_shared<ov::op::v1::Reshape>(matmul, reshape_dims, true);
auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{reshape, reshape}, 2);

auto cos = std::make_shared<ov::op::v0::Cos>(concat);
auto sin = std::make_shared<ov::op::v0::Sin>(concat);

auto rope_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape::dynamic(4));
auto rope = std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{rope_input, cos, sin}, ov::op::internal::RoPE::Config());

model = std::make_shared<ov::Model>(ov::NodeVector{ rope }, ov::ParameterVector{ input, rope_input, reshape_dims });
manager.register_pass<IncreasePositionIdsPrecision>();
}
{
auto input = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{ -1, -1 });
auto input_convert = std::make_shared<ov::op::v0::Convert>(input, ov::element::i32);
auto input_unsqueeze = std::make_shared<ov::op::v0::Unsqueeze>(input_convert, std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{1}));
auto rotary_embd_const = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{1, 64, 1});
auto reshape_dims = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::PartialShape{3});

auto input_convert_f32 = std::make_shared<ov::op::v0::Convert>(input_unsqueeze, ov::element::f32);
auto rotary_embd_const_convert_f32 = std::make_shared<ov::op::v0::Convert>(rotary_embd_const, ov::element::f32);

auto matmul = std::make_shared<ov::op::v0::MatMul>(input_convert_f32, rotary_embd_const_convert_f32);
auto reshape = std::make_shared<ov::op::v1::Reshape>(matmul, reshape_dims, true);
auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{reshape, reshape}, 2);

auto cos = std::make_shared<ov::op::v0::Cos>(concat);
auto sin = std::make_shared<ov::op::v0::Sin>(concat);

auto cos_convert = std::make_shared<ov::op::v0::Convert>(cos, ov::element::f16);
auto sin_convert = std::make_shared<ov::op::v0::Convert>(sin, ov::element::f16);

auto rope_input = std::make_shared<ov::op::v0::Parameter>(ov::element::f16, ov::PartialShape::dynamic(4));
auto rope = std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{rope_input, cos_convert, sin_convert}, ov::op::internal::RoPE::Config());

model_ref = std::make_shared<ov::Model>(ov::NodeVector{ rope }, ov::ParameterVector{ input, rope_input, reshape_dims });
}
comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
}
