Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix RopeFusion transformation after applying SDPA to PagedAttention conversion #28447

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -557,9 +557,11 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
} else {
auto ListConstruct_452_Concat =
makePattern<opset1::Concat>({seq_length, {-1}, {head_cnt}, {ndims / 2}, {2}}, {{"axis", 0}});
auto const_target_shape_0 = makeConst({0, 0, head_cnt, ndims / 2, 2});
auto const_target_shape_1 = makeConst({seq_len, batch, head_cnt, ndims / 2, 2});
reshape_Reshape_453 = makePattern<opset1::Reshape>(
{slice_Slice_437 | var_split_1->output(0), ListConstruct_452_Concat | const_target_shape_1});
reshape_Reshape_453 =
makePattern<opset1::Reshape>({slice_Slice_437 | var_split_1->output(0),
ListConstruct_452_Concat | const_target_shape_1 | const_target_shape_0});
}

auto x_even = makePattern<opset8::Gather>({reshape_Reshape_453, 0, -1}, {{"batch_dims", 0}});
Expand Down Expand Up @@ -588,6 +590,7 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
} else {
auto ListConstruct_379_Concat =
makePattern<opset1::Concat>({seq_length, {-1}, {1}, {ndims / 2}, {2}}, {{"axis", 0}});
auto const_target_shape_0 = makeConst({1, -1, 1, ndims / 2, 2});
auto const_target_shape_2 = makeConst({seq_len, batch, 1, ndims / 2, 2});

auto slice_Slice_449 = makePattern<ov::opset8::Slice>({cos_sin_cache, {0}, seq_length, {1}, {0}});
Expand All @@ -596,7 +599,7 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
// [seq_length, 1, batch, half_rotary_dims, 2]
view_Reshape_460 =
makePattern<opset1::Reshape>({slice_StridedSlice_449 | slice_Slice_449 | var_split_2->output(0),
ListConstruct_379_Concat | const_target_shape_2},
ListConstruct_379_Concat | const_target_shape_0 | const_target_shape_2},
{{"special_zero", false}});
}

Expand All @@ -609,12 +612,17 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
auto sub_Subtract_469 = makePattern<opset1::Add>({x_even_cos, neg_x_odd_sin}, {{"auto_broadcast", "numpy"}});

auto y_even = makePattern<opset1::Unsqueeze>({sub_Subtract_469, -1});
auto const_y_even_reshape = makeConst({1, -1, head_cnt, ndims / 2, 1});
auto y_even_reshape =
makePattern<opset1::Reshape>({sub_Subtract_469, const_y_even_reshape}, {{"special_zero", false}});
auto x_odd_cos = makePattern<opset1::Multiply>({x_odd, cos_tab}, {{"auto_broadcast", "numpy"}});
auto x_even_sin = makePattern<opset1::Multiply>({x_even, sin_tab}, {{"auto_broadcast", "numpy"}});
auto add_Add_476 = makePattern<opset1::Add>({x_odd_cos, x_even_sin}, {{"auto_broadcast", "numpy"}});
auto y_odd = makePattern<opset1::Unsqueeze>({add_Add_476, -1});
auto const_y_odd_reshape = makeConst({1, -1, head_cnt, ndims / 2, 1});
auto y_odd_reshape = makePattern<opset1::Reshape>({add_Add_476, const_y_odd_reshape}, {{"special_zero", false}});

auto stack_481 = makePattern<opset1::Concat>({y_even, y_odd}, {{"axis", -1}});
auto stack_481 = makePattern<opset1::Concat>({y_even | y_even_reshape, y_odd | y_odd_reshape}, {{"axis", -1}});

auto ShapeOf_135133 = makePattern<opset1::ShapeOf>({stack_481});
auto flatten_Slice_497 = GenSlice(ShapeOf_135133, 0, 3, 1, 0);
Expand All @@ -629,9 +637,11 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s
{{"special_zero", true}});
} else {
// [length, batch, head_cnt, half_rotary_dims, 2]
auto const_target_shape_0 = makeConst({0, 0, head_cnt, ndims});
const_target_shape_3 = makeConst({seq_len, batch, head_cnt, ndims});
flatten_Reshape_501 = makePattern<opset1::Reshape>({stack_481, flatten_Concat_500 | const_target_shape_3},
{{"special_zero", true}});
flatten_Reshape_501 =
makePattern<opset1::Reshape>({stack_481, flatten_Concat_500 | const_target_shape_3 | const_target_shape_0},
{{"special_zero", true}});
}
auto slice_Slice_443 = GenSlice(input_key, ndims, INT_MAX, 1, 3);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1131,3 +1131,88 @@ TEST_F(TransformationTestsF, ConvertToROPE_Flux_mul_squeeze_unsqueeze) {
}
comparator.enable(FunctionsComparator::ATTRIBUTES);
}

TEST_F(TransformationTestsF, ConvertToROPE_chatGLM3_PagedAttention) {
disable_rt_info_check();
const int batch = -1;
const int seq_len = 1;
const int num_heads = 32;
const int num_heads_kv = 2;
const int ndims = 128;
const int rotary_ndims = 64;
const int hidden_size = ndims * (num_heads + 2 * num_heads_kv);
const int hidden_size_q = ndims * num_heads;
const int hidden_size_kv = ndims * num_heads_kv;
using namespace ov;
{
auto input =
std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{seq_len, batch, hidden_size});
auto cos_sin = std::make_shared<ov::opset1::Parameter>(ov::element::f32,
ov::PartialShape{seq_len, batch, rotary_ndims / 2, 2});
auto aten_slice_Slice_1 = makeOP<opset8::Slice>({cos_sin, {0}, {1}, {1}, {0}});
auto aten_view_Reshape = makeOP<opset1::Reshape>({aten_slice_Slice_1, {seq_len, batch, 1, rotary_ndims / 2, 2}},
{{"special_zero", false}});
auto aten_select_Gather_1 = makeOP<opset8::Gather>({aten_view_Reshape, 0, -1}, {{"batch_dims", 0}});
auto aten_select_Gather_3 = makeOP<opset8::Gather>({aten_view_Reshape, 1, -1}, {{"batch_dims", 0}});

auto attn_prim_ListUnpack =
makeOP<opset1::VariadicSplit>({input, -1, {hidden_size_q, hidden_size_kv, hidden_size_kv}});
auto attn_aten_view_Reshape_2 =
makeOP<opset1::Reshape>({attn_prim_ListUnpack->output(0), {0, 0, num_heads, ndims}},
{{"special_zero", true}});
auto VariadicSplit_29663 =
makeOP<opset1::VariadicSplit>({attn_aten_view_Reshape_2, 3, {rotary_ndims, ndims - rotary_ndims}});
auto aten_reshape_Reshape_55 =
makeOP<opset1::Reshape>({VariadicSplit_29663->output(0), {0, 0, num_heads, rotary_ndims / 2, 2}},
{{"special_zero", true}});
auto aten_select_Gather_440 = makeOP<opset8::Gather>({aten_reshape_Reshape_55, 0, -1}, {{"batch_dims", 0}});
auto aten_mul_Multiply_276 =
makeOP<opset1::Multiply>({aten_select_Gather_440, aten_select_Gather_1}, {{"auto_broadcast", "numpy"}});
auto aten_select_Gather_442 = makeOP<opset8::Gather>({aten_reshape_Reshape_55, 1, -1}, {{"batch_dims", 0}});
auto aten_mul_Multiply_277 =
makeOP<opset1::Multiply>({aten_select_Gather_442, aten_select_Gather_3}, {{"auto_broadcast", "numpy"}});
auto Multiply_34833 =
makeOP<opset1::Multiply>({aten_mul_Multiply_277, -1.000000f}, {{"auto_broadcast", "numpy"}});
auto aten_sub_Subtract_55 =
makeOP<opset1::Add>({aten_mul_Multiply_276, Multiply_34833}, {{"auto_broadcast", "numpy"}});
auto Unsqueeze_62197 = makeOP<opset1::Reshape>({aten_sub_Subtract_55, {1, -1, num_heads, rotary_ndims / 2, 1}},
{{"special_zero", false}});
auto aten_mul_Multiply_278 =
makeOP<opset1::Multiply>({aten_select_Gather_442, aten_select_Gather_1}, {{"auto_broadcast", "numpy"}});
auto aten_mul_Multiply_279 =
makeOP<opset1::Multiply>({aten_select_Gather_440, aten_select_Gather_3}, {{"auto_broadcast", "numpy"}});
auto aten_add_Add_55 =
makeOP<opset1::Add>({aten_mul_Multiply_278, aten_mul_Multiply_279}, {{"auto_broadcast", "numpy"}});
auto Unsqueeze_62198 = makeOP<opset1::Reshape>({aten_add_Add_55, {1, -1, num_heads, rotary_ndims / 2, 1}},
{{"special_zero", false}});
auto aten_stack_55 = makeOP<opset1::Concat>({Unsqueeze_62197, Unsqueeze_62198}, {{"axis", -1}});
auto aten_flatten_Reshape_55 =
makeOP<opset1::Reshape>({aten_stack_55, {0, 0, num_heads, rotary_ndims}}, {{"special_zero", true}});
auto aten_cat_Concat_55 =
makeOP<opset1::Concat>({aten_flatten_Reshape_55, VariadicSplit_29663->output(1)}, {{"axis", -1}});

model = std::make_shared<ov::Model>(ov::NodeVector{aten_cat_Concat_55}, ov::ParameterVector{input, cos_sin});
}
manager.register_pass<ov::pass::RoPEFusion>(false);
{
auto input =
std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{seq_len, batch, hidden_size});
auto gather_cos_sin =
std::make_shared<ov::opset1::Parameter>(ov::element::f32,
ov::PartialShape{seq_len, batch, rotary_ndims / 2, 2});
auto rope = makeOP<ov::op::internal::RoPE>({input, gather_cos_sin, gather_cos_sin},
{{"config.slice_start", 0},
{"config.slice_stop", 4096},
{"config.input_trans0213", false},
{"config.output_trans0213", false},
{"config.is_interleaved", false},
{"config.rotary_ndims", rotary_ndims},
{"config.is_chatglm", true},
{"config.support_2d_rope", false},
{"config.is_qwen", false},
{"config.head_cnt", num_heads},
{"config.head_size", ndims},
{"config.gather_position_arg_id", 0}});
model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, gather_cos_sin});
}
}
Loading