diff --git a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp index c82853ec56e9ed..6b12f56215ca83 100644 --- a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -557,9 +557,11 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s } else { auto ListConstruct_452_Concat = makePattern({seq_length, {-1}, {head_cnt}, {ndims / 2}, {2}}, {{"axis", 0}}); + auto const_target_shape_0 = makeConst({0, 0, head_cnt, ndims / 2, 2}); auto const_target_shape_1 = makeConst({seq_len, batch, head_cnt, ndims / 2, 2}); - reshape_Reshape_453 = makePattern( - {slice_Slice_437 | var_split_1->output(0), ListConstruct_452_Concat | const_target_shape_1}); + reshape_Reshape_453 = + makePattern({slice_Slice_437 | var_split_1->output(0), + ListConstruct_452_Concat | const_target_shape_1 | const_target_shape_0}); } auto x_even = makePattern({reshape_Reshape_453, 0, -1}, {{"batch_dims", 0}}); @@ -588,6 +590,7 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s } else { auto ListConstruct_379_Concat = makePattern({seq_length, {-1}, {1}, {ndims / 2}, {2}}, {{"axis", 0}}); + auto const_target_shape_0 = makeConst({1, -1, 1, ndims / 2, 2}); auto const_target_shape_2 = makeConst({seq_len, batch, 1, ndims / 2, 2}); auto slice_Slice_449 = makePattern({cos_sin_cache, {0}, seq_length, {1}, {0}}); @@ -596,7 +599,7 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s // [seq_length, 1, batch, half_rotary_dims, 2] view_Reshape_460 = makePattern({slice_StridedSlice_449 | slice_Slice_449 | var_split_2->output(0), - ListConstruct_379_Concat | const_target_shape_2}, + ListConstruct_379_Concat | const_target_shape_0 | const_target_shape_2}, {{"special_zero", false}}); } @@ -609,12 +612,17 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s auto sub_Subtract_469 = makePattern({x_even_cos, neg_x_odd_sin}, {{"auto_broadcast", "numpy"}}); auto y_even = makePattern({sub_Subtract_469, -1}); + auto const_y_even_reshape = makeConst({1, -1, head_cnt, ndims / 2, 1}); + auto y_even_reshape = + makePattern({sub_Subtract_469, const_y_even_reshape}, {{"special_zero", false}}); auto x_odd_cos = makePattern({x_odd, cos_tab}, {{"auto_broadcast", "numpy"}}); auto x_even_sin = makePattern({x_even, sin_tab}, {{"auto_broadcast", "numpy"}}); auto add_Add_476 = makePattern({x_odd_cos, x_even_sin}, {{"auto_broadcast", "numpy"}}); auto y_odd = makePattern({add_Add_476, -1}); + auto const_y_odd_reshape = makeConst({1, -1, head_cnt, ndims / 2, 1}); + auto y_odd_reshape = makePattern({add_Add_476, const_y_odd_reshape}, {{"special_zero", false}}); - auto stack_481 = makePattern({y_even, y_odd}, {{"axis", -1}}); + auto stack_481 = makePattern({y_even | y_even_reshape, y_odd | y_odd_reshape}, {{"axis", -1}}); auto ShapeOf_135133 = makePattern({stack_481}); auto flatten_Slice_497 = GenSlice(ShapeOf_135133, 0, 3, 1, 0); @@ -629,9 +637,11 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s {{"special_zero", true}}); } else { // [length, batch, head_cnt, half_rotary_dims, 2] + auto const_target_shape_0 = makeConst({0, 0, head_cnt, ndims}); const_target_shape_3 = makeConst({seq_len, batch, head_cnt, ndims}); - flatten_Reshape_501 = makePattern({stack_481, flatten_Concat_500 | const_target_shape_3}, - {{"special_zero", true}}); + flatten_Reshape_501 = + makePattern({stack_481, flatten_Concat_500 | const_target_shape_3 | const_target_shape_0}, + {{"special_zero", true}}); } auto slice_Slice_443 = GenSlice(input_key, ndims, INT_MAX, 1, 3); diff --git a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp index 0328831ff1a69c..1b34e0c4423d3d 100644 --- a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -1131,3 +1131,88 @@ TEST_F(TransformationTestsF, ConvertToROPE_Flux_mul_squeeze_unsqueeze) { } comparator.enable(FunctionsComparator::ATTRIBUTES); } + +TEST_F(TransformationTestsF, ConvertToROPE_chatGLM3_PagedAttention) { + disable_rt_info_check(); + const int batch = -1; + const int seq_len = 1; + const int num_heads = 32; + const int num_heads_kv = 2; + const int ndims = 128; + const int rotary_ndims = 64; + const int hidden_size = ndims * (num_heads + 2 * num_heads_kv); + const int hidden_size_q = ndims * num_heads; + const int hidden_size_kv = ndims * num_heads_kv; + using namespace ov; + { + auto input = + std::make_shared(ov::element::f32, ov::PartialShape{seq_len, batch, hidden_size}); + auto cos_sin = std::make_shared(ov::element::f32, + ov::PartialShape{seq_len, batch, rotary_ndims / 2, 2}); + auto aten_slice_Slice_1 = makeOP({cos_sin, {0}, {1}, {1}, {0}}); + auto aten_view_Reshape = makeOP({aten_slice_Slice_1, {seq_len, batch, 1, rotary_ndims / 2, 2}}, + {{"special_zero", false}}); + auto aten_select_Gather_1 = makeOP({aten_view_Reshape, 0, -1}, {{"batch_dims", 0}}); + auto aten_select_Gather_3 = makeOP({aten_view_Reshape, 1, -1}, {{"batch_dims", 0}}); + + auto attn_prim_ListUnpack = + makeOP({input, -1, {hidden_size_q, hidden_size_kv, hidden_size_kv}}); + auto attn_aten_view_Reshape_2 = + makeOP({attn_prim_ListUnpack->output(0), {0, 0, num_heads, ndims}}, + {{"special_zero", true}}); + auto VariadicSplit_29663 = + makeOP({attn_aten_view_Reshape_2, 3, {rotary_ndims, ndims - rotary_ndims}}); + auto aten_reshape_Reshape_55 = + makeOP({VariadicSplit_29663->output(0), {0, 0, num_heads, rotary_ndims / 2, 2}}, + {{"special_zero", true}}); + auto aten_select_Gather_440 = makeOP({aten_reshape_Reshape_55, 0, -1}, {{"batch_dims", 0}}); + auto aten_mul_Multiply_276 = + makeOP({aten_select_Gather_440, aten_select_Gather_1}, {{"auto_broadcast", "numpy"}}); + auto aten_select_Gather_442 = makeOP({aten_reshape_Reshape_55, 1, -1}, {{"batch_dims", 0}}); + auto aten_mul_Multiply_277 = + makeOP({aten_select_Gather_442, aten_select_Gather_3}, {{"auto_broadcast", "numpy"}}); + auto Multiply_34833 = + makeOP({aten_mul_Multiply_277, -1.000000f}, {{"auto_broadcast", "numpy"}}); + auto aten_sub_Subtract_55 = + makeOP({aten_mul_Multiply_276, Multiply_34833}, {{"auto_broadcast", "numpy"}}); + auto Unsqueeze_62197 = makeOP({aten_sub_Subtract_55, {1, -1, num_heads, rotary_ndims / 2, 1}}, + {{"special_zero", false}}); + auto aten_mul_Multiply_278 = + makeOP({aten_select_Gather_442, aten_select_Gather_1}, {{"auto_broadcast", "numpy"}}); + auto aten_mul_Multiply_279 = + makeOP({aten_select_Gather_440, aten_select_Gather_3}, {{"auto_broadcast", "numpy"}}); + auto aten_add_Add_55 = + makeOP({aten_mul_Multiply_278, aten_mul_Multiply_279}, {{"auto_broadcast", "numpy"}}); + auto Unsqueeze_62198 = makeOP({aten_add_Add_55, {1, -1, num_heads, rotary_ndims / 2, 1}}, + {{"special_zero", false}}); + auto aten_stack_55 = makeOP({Unsqueeze_62197, Unsqueeze_62198}, {{"axis", -1}}); + auto aten_flatten_Reshape_55 = + makeOP({aten_stack_55, {0, 0, num_heads, rotary_ndims}}, {{"special_zero", true}}); + auto aten_cat_Concat_55 = + makeOP({aten_flatten_Reshape_55, VariadicSplit_29663->output(1)}, {{"axis", -1}}); + + model = std::make_shared(ov::NodeVector{aten_cat_Concat_55}, ov::ParameterVector{input, cos_sin}); + } + manager.register_pass(false); + { + auto input = + std::make_shared(ov::element::f32, ov::PartialShape{seq_len, batch, hidden_size}); + auto gather_cos_sin = + std::make_shared(ov::element::f32, + ov::PartialShape{seq_len, batch, rotary_ndims / 2, 2}); + auto rope = makeOP({input, gather_cos_sin, gather_cos_sin}, + {{"config.slice_start", 0}, + {"config.slice_stop", 4096}, + {"config.input_trans0213", false}, + {"config.output_trans0213", false}, + {"config.is_interleaved", false}, + {"config.rotary_ndims", rotary_ndims}, + {"config.is_chatglm", true}, + {"config.support_2d_rope", false}, + {"config.is_qwen", false}, + {"config.head_cnt", num_heads}, + {"config.head_size", ndims}, + {"config.gather_position_arg_id", 0}}); + model_ref = std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input, gather_cos_sin}); + } +} \ No newline at end of file