From 599465196fc9105ff95c2708b3d9528321e1d56b Mon Sep 17 00:00:00 2001 From: Vasily Shamporov Date: Thu, 9 Jan 2025 13:29:07 +0100 Subject: [PATCH] Add cache rotation inputs and CPU kernel implementation for cache rotation (#27088) Tickets: 153783 --- .github/workflows/job_cxx_unit_tests.yml | 6 + CMakeLists.txt | 2 +- .../core/offline_transformations.cpp | 9 +- .../state_management_pattern.hpp | 10 +- .../state_management_pattern.cpp | 39 +- .../openvino/pass/sdpa_to_paged_attention.hpp | 7 +- src/core/src/op/paged_attention.cpp | 40 +- src/core/src/pass/sdpa_to_paged_attention.cpp | 44 +- src/core/tests/type_prop/paged_attention.cpp | 130 ++++ .../kernels/scaled_attn/cache_rotation.hpp | 234 +++++++ .../src/nodes/kernels/scaled_attn/common.hpp | 56 +- .../nodes/kernels/scaled_attn/executor_pa.cpp | 149 ++++- .../scaled_attn/executor_pa_common.hpp | 31 +- .../intel_cpu/src/nodes/paged_attn.cpp | 20 +- .../intel_cpu/tests/unit/CMakeLists.txt | 5 + .../tests/unit/vectorized/CMakeLists.txt | 89 +++ .../vectorized/paged_attn_cache_rotation.cpp | 509 ++++++++++++++++ .../intel_cpu/tests/unit/vectorized/stub.cpp | 12 + .../intel_gpu/primitives/paged_attention.hpp | 7 +- .../src/graph/impls/ocl/paged_attention.cpp | 5 + .../intel_gpu/src/graph/paged_attention.cpp | 1 + .../kernel_selector/cl_kernels/pa_sdpa_opt.cl | 14 +- .../kernels/sdpa/pa_sdpa_kernel_opt.cpp | 3 + .../kernels/sdpa/sdpa_kernel_base.h | 1 + .../src/plugin/ops/paged_attention.cpp | 7 + .../generate_ref_diffs.py | 10 +- .../transformation_tests/sdpa2pa_ref_diff.py | 571 +++++++++--------- .../test_pa_transformation.py | 79 ++- 28 files changed, 1694 insertions(+), 396 deletions(-) create mode 100644 src/core/tests/type_prop/paged_attention.cpp create mode 100644 src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/cache_rotation.hpp create mode 100644 src/plugins/intel_cpu/tests/unit/vectorized/CMakeLists.txt create mode 100644 src/plugins/intel_cpu/tests/unit/vectorized/paged_attn_cache_rotation.cpp create mode 100644 src/plugins/intel_cpu/tests/unit/vectorized/stub.cpp diff --git a/.github/workflows/job_cxx_unit_tests.yml b/.github/workflows/job_cxx_unit_tests.yml index 52a2b3f4d287c8..a2a5762b4ea0bf 100644 --- a/.github/workflows/job_cxx_unit_tests.yml +++ b/.github/workflows/job_cxx_unit_tests.yml @@ -195,6 +195,12 @@ jobs: ${{ env.SETUPVARS_COMMAND }} ${{ env.INSTALL_TEST_DIR }}/ov_cpu_unit_tests --gtest_print_time=1 --gtest_output=xml:${{ env.INSTALL_TEST_DIR }}/TEST-CPUUnitTests.xml + - name: CPU plugin unit tests (vectorized) + if: fromJSON(inputs.affected-components).CPU.test + run: | + ${{ env.SETUPVARS_COMMAND }} + ${{ env.INSTALL_TEST_DIR }}/ov_cpu_unit_tests_vectorized --gtest_print_time=1 --gtest_output=xml:${{ env.INSTALL_TEST_DIR }}/TEST-CPUUnitTestsVectorized.xml + - name: ov_subgraphs_dumper_tests tests run: | ${{ env.SETUPVARS_COMMAND }} diff --git a/CMakeLists.txt b/CMakeLists.txt index 65a72ef8f4936e..1cbdbe72507f6d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -185,4 +185,4 @@ endif() # provides a callback function to describe each component in repo include(cmake/packaging/packaging.cmake) -ov_cpack(${OV_CPACK_COMPONENTS_ALL}) \ No newline at end of file +ov_cpack(${OV_CPACK_COMPONENTS_ALL}) diff --git a/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp b/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp index 90aece1803f4b4..15a609f26d8fe9 100644 --- a/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp +++ 
b/src/bindings/python/src/pyopenvino/core/offline_transformations.cpp @@ -143,15 +143,18 @@ void regmodule_offline_transformations(py::module m) { m_offline_transformations.def( "paged_attention_transformation", - [](py::object& ie_api_model, bool use_block_indices_inputs, bool use_score_outputs) { + [](py::object& ie_api_model, bool use_block_indices_inputs, bool use_score_outputs, bool allow_cache_rotation) { const auto model = Common::utils::convert_to_model(ie_api_model); ov::pass::Manager manager; - manager.register_pass(use_block_indices_inputs, use_score_outputs); + manager.register_pass(use_block_indices_inputs, + use_score_outputs, + allow_cache_rotation); manager.run_passes(model); }, py::arg("model"), py::arg("use_block_indices_inputs") = false, - py::arg("use_score_outputs") = false); + py::arg("use_score_outputs") = false, + py::arg("allow_cache_rotation") = false); m_offline_transformations.def( "stateful_to_stateless_transformation", diff --git a/src/common/transformations/include/transformations/sdpa_to_paged_attention/state_management_pattern.hpp b/src/common/transformations/include/transformations/sdpa_to_paged_attention/state_management_pattern.hpp index 79b4f444cfa791..2e090a4aabaa30 100644 --- a/src/common/transformations/include/transformations/sdpa_to_paged_attention/state_management_pattern.hpp +++ b/src/common/transformations/include/transformations/sdpa_to_paged_attention/state_management_pattern.hpp @@ -24,8 +24,12 @@ class ov::pass::StateManagementPattern : public ov::pass::MatcherPass { ParameterVector& parameters_to_remove, int& layer_index, ov::Output max_context_len, - ParameterVector& block_indices_inputs, + ParameterVector& block_indices_inputs_for_each_layer, ResultVector& score_results, - bool use_block_indices, - bool use_score_outputs); + bool use_per_layer_block_indices_inputs, + bool use_score_outputs, + bool allow_cache_rotation, + ParameterVector& rotated_block_indices_inputs_for_each_layer, + ParameterVector& rotation_deltas_inputs_for_each_layer, + std::shared_ptr model_rotation_trig_lut); }; diff --git a/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp index a36085c34237a4..7b896463fdd51b 100644 --- a/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp +++ b/src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp @@ -15,6 +15,7 @@ #include "openvino/op/gather.hpp" #include "openvino/op/multiply.hpp" #include "openvino/op/paged_attention.hpp" +#include "openvino/op/parameter.hpp" #include "openvino/op/reshape.hpp" #include "openvino/op/scaled_dot_product_attention.hpp" #include "openvino/op/select.hpp" @@ -70,10 +71,14 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par ParameterVector& parameters_to_remove, int& layer_index, Output max_context_len, - ParameterVector& block_indices_inputs, + ParameterVector& block_indices_inputs_for_each_layer, ResultVector& score_results, - bool use_block_indices_inputs, - bool use_score_outputs) { + bool use_per_layer_block_indices_inputs, + bool use_score_outputs, + bool allow_cache_rotation, + ParameterVector& rotated_block_indices_inputs_for_each_layer, + ParameterVector& rotation_deltas_inputs_for_each_layer, + std::shared_ptr model_rotation_trig_lut) { MATCHER_SCOPE(StateManagementPattern); auto k_current = pattern::any_input(); @@ 
-176,9 +181,11 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par &model_remaining_params, &sliding_window, ¶meters_to_remove, - &block_indices_inputs, + &block_indices_inputs_for_each_layer, &score_results, - &layer_index](ov::pass::pattern::Matcher& m) { + &layer_index, + &rotated_block_indices_inputs_for_each_layer, + &rotation_deltas_inputs_for_each_layer](ov::pass::pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); auto real_q = pattern_map.at(q); @@ -382,11 +389,27 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par max_context_len.get_node_shared_ptr()}; pa_arguments.insert(pa_arguments.end(), additional_params.begin(), additional_params.end()); - if (use_block_indices_inputs) { + if (use_per_layer_block_indices_inputs) { auto block_indices = setName(std::make_shared(element::i32, PartialShape{-1}), "block_indices." + std::to_string(layer_index - 1)); pa_arguments.insert(pa_arguments.begin() + 7, block_indices); - block_indices_inputs.push_back(block_indices); + block_indices_inputs_for_each_layer.push_back(block_indices); + } + + OPENVINO_ASSERT(pa_arguments.size() == 13); + + if (allow_cache_rotation) { + auto rotated_block_indices = setName(std::make_shared(element::i32, PartialShape{-1}), + "rotated_block_indices." + std::to_string(layer_index - 1)); + auto rotation_deltas = setName(std::make_shared(element::i32, PartialShape{-1, -1}), + "rotation_deltas." + std::to_string(layer_index - 1)); + + pa_arguments.insert(pa_arguments.begin() + 13, rotated_block_indices); + pa_arguments.insert(pa_arguments.begin() + 14, rotation_deltas); + pa_arguments.insert(pa_arguments.begin() + 15, model_rotation_trig_lut); + + rotated_block_indices_inputs_for_each_layer.push_back(rotated_block_indices); + rotation_deltas_inputs_for_each_layer.push_back(rotation_deltas); } auto paged_attention = std::make_shared(pa_arguments); @@ -444,4 +467,4 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par auto m = std::make_shared(sdpa_variants, matcher_name); register_matcher(m, callback); -} \ No newline at end of file +} diff --git a/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp b/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp index d52e78dbd6a489..b1b0bb6078d987 100644 --- a/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp +++ b/src/core/include/openvino/pass/sdpa_to_paged_attention.hpp @@ -19,12 +19,15 @@ class OPENVINO_API SDPAToPagedAttention : public ModelPass { public: OPENVINO_MODEL_PASS_RTTI("SDPAToPagedAttention"); - explicit SDPAToPagedAttention(bool use_block_indices_inputs = false, bool use_score_outputs = false); + explicit SDPAToPagedAttention(bool use_per_layer_block_indices_inputs = false, + bool use_score_outputs = false, + bool allow_cache_rotation = false); bool run_on_model(const std::shared_ptr& model) override; private: - bool m_use_block_indices_inputs; + bool m_use_per_layer_block_indices_inputs; bool m_use_score_outputs; + bool m_allow_cache_rotation; }; } // namespace pass } // namespace ov diff --git a/src/core/src/op/paged_attention.cpp b/src/core/src/op/paged_attention.cpp index cdcb66e86ee33e..1feeab44b7f018 100644 --- a/src/core/src/op/paged_attention.cpp +++ b/src/core/src/op/paged_attention.cpp @@ -19,8 +19,8 @@ void PagedAttentionExtension::validate_and_infer_types() { OV_OP_SCOPE(PagedAttentionExtension_validate_and_infer_types); NODE_VALIDATION_CHECK(this, - get_input_size() == 13, - "PagedAttensionExtension expects 
13 inputs, but it has ", + get_input_size() == 13 || get_input_size() == 16, + "PagedAttensionExtension expects 13 or 16 inputs, but it has ", get_input_size()); NODE_VALIDATION_CHECK( @@ -147,6 +147,42 @@ void PagedAttentionExtension::validate_and_infer_types() { get_input_element_type(12), "."); + if (get_input_size() == 16) { + NODE_VALIDATION_CHECK( + this, + get_input_partial_shape(13).rank().is_dynamic() || get_input_partial_shape(13).rank().get_length() == 1, + "Input `rotated_block_indices` should either have rank 1 or be omitted, but it has rank ", + get_input_partial_shape(13).rank().get_length(), + "."); + NODE_VALIDATION_CHECK(this, + get_input_element_type(13).is_dynamic() || get_input_element_type(13) == element::i32, + "Element type of `rotated_block_indices` input should be i32, but it is ", + get_input_element_type(13), + "."); + NODE_VALIDATION_CHECK( + this, + get_input_partial_shape(14).rank().is_dynamic() || get_input_partial_shape(14).rank().get_length() == 2, + "Input `rotation_deltas` should either have rank 2 or be omitted, but it has rank ", + get_input_partial_shape(14).rank().get_length(), + "."); + NODE_VALIDATION_CHECK(this, + get_input_element_type(14).is_dynamic() || get_input_element_type(14) == element::i32, + "Element type of `rotation_deltas` input should be i32, but it is ", + get_input_element_type(14), + "."); + NODE_VALIDATION_CHECK( + this, + get_input_partial_shape(15).rank().is_dynamic() || get_input_partial_shape(15).rank().get_length() == 2, + "Input `rotation_trig_lut` should either have rank 2 or be omitted, but it has rank ", + get_input_partial_shape(15).rank().get_length(), + "."); + NODE_VALIDATION_CHECK(this, + get_input_element_type(15).is_dynamic() || get_input_element_type(15) == element::f32, + "Element type of `rotation_trig_lut` input should be f32, but it is ", + get_input_element_type(15), + "."); + } + // value head_size may be not same with key auto out_ps = get_input_partial_shape(0); const auto& key_ps = get_input_partial_shape(1); diff --git a/src/core/src/pass/sdpa_to_paged_attention.cpp b/src/core/src/pass/sdpa_to_paged_attention.cpp index e6fc744bb5ef4f..ea3f3c3e79e196 100644 --- a/src/core/src/pass/sdpa_to_paged_attention.cpp +++ b/src/core/src/pass/sdpa_to_paged_attention.cpp @@ -20,9 +20,12 @@ using namespace ov::op; -ov::pass::SDPAToPagedAttention::SDPAToPagedAttention(bool use_block_indices_inputs, bool use_score_outputs) - : m_use_block_indices_inputs(use_block_indices_inputs), - m_use_score_outputs(use_score_outputs) {} +ov::pass::SDPAToPagedAttention::SDPAToPagedAttention(bool use_per_layer_block_indices_inputs, + bool use_score_outputs, + bool allow_cache_rotation) + : m_use_per_layer_block_indices_inputs(use_per_layer_block_indices_inputs), + m_use_score_outputs(use_score_outputs), + m_allow_cache_rotation(allow_cache_rotation) {} static std::shared_ptr setName(std::shared_ptr node, const char* name) { // Set name for both node and output tensor (should be only one tensor, and any other names will be overriden by a @@ -46,11 +49,18 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr(element::i32, PartialShape{-1}), "subsequence_begins"), setName(std::make_shared(element::i32, PartialShape{-1}), "block_indices_begins"), }; - if (!m_use_block_indices_inputs) { + if (!m_use_per_layer_block_indices_inputs) { auto block_indices = setName(std::make_shared(element::i32, PartialShape{-1}), "block_indices"); model_remaining_params.insert(model_remaining_params.begin() + 2, block_indices); } + 
std::shared_ptr model_rotation_trig_lut; + + if (m_allow_cache_rotation) { + model_rotation_trig_lut = + setName(std::make_shared(element::f32, PartialShape{-1, -1}), "rotation_trig_lut"); + } + auto sliding_window = v0::Constant::create(element::i32, Shape{}, {0}); // sliding_window auto get_parameter = [=](const std::shared_ptr& model, @@ -91,7 +101,10 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr position_ids; @@ -120,11 +133,14 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptroutput(0), - block_indices_inputs, + block_indices_inputs_for_each_layer, score_results, - m_use_block_indices_inputs, - m_use_score_outputs); - + m_use_per_layer_block_indices_inputs, + m_use_score_outputs, + m_allow_cache_rotation, + rotated_block_indices_inputs_for_each_layer, + rotation_deltas_inputs_for_each_layer, + model_rotation_trig_lut); manager.register_pass(unsqueezed_input_ids, max_context_len, position_ids); manager.register_pass(max_context_len); manager.register_pass(max_context_len); @@ -174,14 +190,20 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptrremove_parameter(parameter); } - if (m_use_block_indices_inputs) { - model->add_parameters(block_indices_inputs); + if (m_use_per_layer_block_indices_inputs) { + model->add_parameters(block_indices_inputs_for_each_layer); } if (m_use_score_outputs) { model->add_results(score_results); } + if (m_allow_cache_rotation) { + model->add_parameters(rotated_block_indices_inputs_for_each_layer); + model->add_parameters(rotation_deltas_inputs_for_each_layer); + model->add_parameters({model_rotation_trig_lut}); + } + model->add_parameters(kv_parameters); model->add_parameters(model_remaining_params); model->add_parameters({std::move(max_context_len)}); diff --git a/src/core/tests/type_prop/paged_attention.cpp b/src/core/tests/type_prop/paged_attention.cpp new file mode 100644 index 00000000000000..64fe26b32041ef --- /dev/null +++ b/src/core/tests/type_prop/paged_attention.cpp @@ -0,0 +1,130 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "openvino/op/paged_attention.hpp" + +#include + +#include "openvino/op/parameter.hpp" + +namespace ov { +namespace testing { + +TEST(type_prop, paged_attention_static_13_inputs) { + const auto query = std::make_shared(element::f32, PartialShape{3, 4}); + const auto key = std::make_shared(element::f32, PartialShape{3, 4}); + const auto value = std::make_shared(element::f32, PartialShape{3, 4}); + const auto key_cache = std::make_shared(element::f32, PartialShape{6, 2, 5, 4}); + const auto value_cache = std::make_shared(element::f32, PartialShape{6, 2, 5, 4}); + const auto past_lens = std::make_shared(element::i32, PartialShape{5}); + const auto subsequence_begins = std::make_shared(element::i32, PartialShape{5}); + const auto block_indices = std::make_shared(element::i32, PartialShape{15}); + const auto block_indices_begins = std::make_shared(element::i32, PartialShape{8}); + const auto scale = std::make_shared(element::f32, PartialShape{}); + const auto sliding_window = std::make_shared(element::i32, PartialShape{}); + const auto alibi_slopes = std::make_shared(element::f32, PartialShape{9}); + const auto max_context_len = std::make_shared(element::i32, PartialShape{}); + + ov::OutputVector args = {query, + key, + value, + key_cache, + value_cache, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + scale, + sliding_window, + alibi_slopes, + max_context_len}; + const 
auto op = std::make_shared(args); + EXPECT_EQ(op->get_output_element_type(0), element::f32); + EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{3, 4})); +} + +TEST(type_prop, paged_attention_static_16_inputs_eviction_per_block) { + const auto query = std::make_shared(element::f32, PartialShape{3, 4}); + const auto key = std::make_shared(element::f32, PartialShape{3, 4}); + const auto value = std::make_shared(element::f32, PartialShape{3, 4}); + const auto key_cache = std::make_shared(element::f32, PartialShape{6, 2, 5, 4}); + const auto value_cache = std::make_shared(element::f32, PartialShape{6, 2, 5, 4}); + const auto past_lens = std::make_shared(element::i32, PartialShape{5}); + const auto subsequence_begins = std::make_shared(element::i32, PartialShape{5}); + const auto block_indices = std::make_shared(element::i32, PartialShape{15}); + const auto block_indices_begins = std::make_shared(element::i32, PartialShape{8}); + const auto scale = std::make_shared(element::f32, PartialShape{}); + const auto sliding_window = std::make_shared(element::i32, PartialShape{}); + const auto alibi_slopes = std::make_shared(element::f32, PartialShape{9}); + const auto max_context_len = std::make_shared(element::i32, PartialShape{}); + + const auto rotated_block_indices = std::make_shared(element::i32, PartialShape{3}); + const auto rotation_deltas = std::make_shared(element::i32, PartialShape{12, 1}); + const auto rotation_trig_lut = std::make_shared(element::f32, PartialShape{256, 4}); + + ov::OutputVector args = {query, + key, + value, + key_cache, + value_cache, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + scale, + sliding_window, + alibi_slopes, + max_context_len, + rotated_block_indices, + rotation_deltas, + rotation_trig_lut}; + + const auto op = std::make_shared(args); + EXPECT_EQ(op->get_output_element_type(0), element::f32); + EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{3, 4})); +} + +TEST(type_prop, paged_attention_static_16_inputs_eviction_per_token) { + const auto query = std::make_shared(element::f32, PartialShape{3, 4}); + const auto key = std::make_shared(element::f32, PartialShape{3, 4}); + const auto value = std::make_shared(element::f32, PartialShape{3, 4}); + const auto key_cache = std::make_shared(element::f32, PartialShape{6, 2, 5, 4}); + const auto value_cache = std::make_shared(element::f32, PartialShape{6, 2, 5, 4}); + const auto past_lens = std::make_shared(element::i32, PartialShape{5}); + const auto subsequence_begins = std::make_shared(element::i32, PartialShape{5}); + const auto block_indices = std::make_shared(element::i32, PartialShape{15}); + const auto block_indices_begins = std::make_shared(element::i32, PartialShape{8}); + const auto scale = std::make_shared(element::f32, PartialShape{}); + const auto sliding_window = std::make_shared(element::i32, PartialShape{}); + const auto alibi_slopes = std::make_shared(element::f32, PartialShape{9}); + const auto max_context_len = std::make_shared(element::i32, PartialShape{}); + + const auto rotated_block_indices = std::make_shared(element::i32, PartialShape{3}); + const auto rotation_deltas = std::make_shared(element::i32, PartialShape{12, 5}); + const auto rotation_trig_lut = std::make_shared(element::f32, PartialShape{256, 4}); + + ov::OutputVector args = {query, + key, + value, + key_cache, + value_cache, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + scale, + sliding_window, + alibi_slopes, + max_context_len, + rotated_block_indices, + 
rotation_deltas, + rotation_trig_lut}; + + const auto op = std::make_shared(args); + EXPECT_EQ(op->get_output_element_type(0), element::f32); + EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{3, 4})); +} + +} // namespace testing +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/cache_rotation.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/cache_rotation.hpp new file mode 100644 index 00000000000000..552be63bd29a36 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/cache_rotation.hpp @@ -0,0 +1,234 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include "common.hpp" +#include "openvino/openvino.hpp" + +#if defined(HAVE_AVX2) || defined(HAVE_AVX512F) +# include +#endif + +#if defined(HAVE_AVX512F) +template +inline static void rotate_kv_cache_chunk_avx512(CT* current_x_values_ptr, + CT* current_y_values_ptr, + float* current_rotation_coeffts_cos_ptr, + float* current_rotation_coeffts_sin_ptr, + size_t num_vectorized_elements_per_iteration, + bool is_tail) { + using namespace ov::Extensions::Cpu::XARCH; + + auto result_x = _mm512_setzero_ps(); + auto result_y = _mm512_setzero_ps(); + + auto coeffts_cos = _mm512_undefined_ps(); + auto coeffts_sin = _mm512_undefined_ps(); + + auto cache_values_x = _mm512_undefined_ps(); + auto cache_values_y = _mm512_undefined_ps(); + + if (!is_tail) { + coeffts_cos = mm512_uni_loadu_ps(current_rotation_coeffts_cos_ptr); + coeffts_sin = mm512_uni_loadu_ps(current_rotation_coeffts_sin_ptr); + + cache_values_x = mm512_uni_loadu_ps(current_x_values_ptr); + cache_values_y = mm512_uni_loadu_ps(current_y_values_ptr); + } else { + coeffts_cos = mm512_uni_loadu_tail_ps(current_rotation_coeffts_cos_ptr, num_vectorized_elements_per_iteration); + coeffts_sin = mm512_uni_loadu_tail_ps(current_rotation_coeffts_sin_ptr, num_vectorized_elements_per_iteration); + + cache_values_x = mm512_uni_loadu_tail_ps(current_x_values_ptr, num_vectorized_elements_per_iteration); + cache_values_y = mm512_uni_loadu_tail_ps(current_y_values_ptr, num_vectorized_elements_per_iteration); + } + + result_x = _mm512_fmadd_ps(cache_values_x, coeffts_cos, result_x); + result_x = _mm512_fnmadd_ps(cache_values_y, coeffts_sin, result_x); // negative multiply-add + + result_y = _mm512_fmadd_ps(cache_values_x, coeffts_sin, result_y); + result_y = _mm512_fmadd_ps(cache_values_y, coeffts_cos, result_y); + + if (!is_tail) { + mm512_uni_storeu_ps(current_x_values_ptr, result_x); + mm512_uni_storeu_ps(current_y_values_ptr, result_y); + } else { + mm512_uni_storeu_tail_ps(current_x_values_ptr, result_x, num_vectorized_elements_per_iteration); + mm512_uni_storeu_tail_ps(current_y_values_ptr, result_y, num_vectorized_elements_per_iteration); + } +} +#endif + +#if defined(HAVE_AVX2) +template +inline static void rotate_kv_cache_chunk_avx2(CT* current_x_values_ptr, + CT* current_y_values_ptr, + float* current_rotation_coeffts_cos_ptr, + float* current_rotation_coeffts_sin_ptr, + size_t num_vectorized_elements_per_iteration, + size_t is_tail) { + using namespace ov::Extensions::Cpu::XARCH; + + auto result_x = _mm256_setzero_ps(); + auto result_y = _mm256_setzero_ps(); + + auto coeffts_cos = _mm256_undefined_ps(); + auto coeffts_sin = _mm256_undefined_ps(); + + auto cache_values_x = _mm256_undefined_ps(); + auto cache_values_y = _mm256_undefined_ps(); + + if (!is_tail) { + coeffts_cos = mm256_uni_loadu_ps(current_rotation_coeffts_cos_ptr); + coeffts_sin = 
mm256_uni_loadu_ps(current_rotation_coeffts_sin_ptr); + + cache_values_x = mm256_uni_loadu_ps(current_x_values_ptr); + cache_values_y = mm256_uni_loadu_ps(current_y_values_ptr); + } else { + coeffts_cos = mm256_uni_loadu_tail_ps(current_rotation_coeffts_cos_ptr, num_vectorized_elements_per_iteration); + coeffts_sin = mm256_uni_loadu_tail_ps(current_rotation_coeffts_sin_ptr, num_vectorized_elements_per_iteration); + + cache_values_x = mm256_uni_loadu_tail_ps(current_x_values_ptr, num_vectorized_elements_per_iteration); + cache_values_y = mm256_uni_loadu_tail_ps(current_y_values_ptr, num_vectorized_elements_per_iteration); + } + + result_x = _mm256_fmadd_ps(cache_values_x, coeffts_cos, result_x); + result_x = _mm256_fnmadd_ps(cache_values_y, coeffts_sin, result_x); // negative multiply-add + + result_y = _mm256_fmadd_ps(cache_values_x, coeffts_sin, result_y); + result_y = _mm256_fmadd_ps(cache_values_y, coeffts_cos, result_y); + + if (!is_tail) { + mm256_uni_storeu_ps(current_x_values_ptr, result_x); + mm256_uni_storeu_ps(current_y_values_ptr, result_y); + } else { + mm256_uni_storeu_tail_ps(current_x_values_ptr, result_x, num_vectorized_elements_per_iteration); + mm256_uni_storeu_tail_ps(current_y_values_ptr, result_y, num_vectorized_elements_per_iteration); + } +} +#endif + +template +inline static void rotate_kv_cache_block_opt(CT* cache_block_ptr, + float* block_rotation_coefficients_ptr, + size_t num_heads, + size_t block_size, + size_t embedding_size) { +#if !defined(HAVE_AVX2) && !defined(HAVE_AVX512F) + OPENVINO_THROW("host CPU must support either AVX2 or AVX512 instructions"); +#else + bool is_tail = false; + +# if defined(HAVE_AVX512F) + constexpr size_t vec_len_in_f32_elts = ov::Extensions::Cpu::XARCH::vec_len_f32_avx512; +# else // HAVE_AVX2 + constexpr size_t vec_len_in_f32_elts = ov::Extensions::Cpu::XARCH::vec_len_f32_avx2; +# endif // defined(HAVE_AVX512F) + + size_t num_processed_elements_per_iteration = + 2 * vec_len_in_f32_elts; // implementations act on pairs of cache values at once using separate registers, each + // elt is expanded to f32 on load + size_t num_iterations = embedding_size / num_processed_elements_per_iteration; + + if (embedding_size >= num_processed_elements_per_iteration) { + OPENVINO_ASSERT(!(num_processed_elements_per_iteration % vec_len_in_f32_elts)); + } else { + is_tail = true; + OPENVINO_ASSERT(!(embedding_size % 2)); + num_processed_elements_per_iteration = embedding_size; + num_iterations = 1; + } + + CT* current_cache_element_ptr = cache_block_ptr; + + for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { + // the rotation coefficients are taken to be the same for all heads + float* current_rotation_coeffts_ptr = block_rotation_coefficients_ptr; + for (size_t tok_idx = 0; tok_idx < block_size; + tok_idx++, current_cache_element_ptr += embedding_size, current_rotation_coeffts_ptr += embedding_size) { + CT* current_x_values_ptr = current_cache_element_ptr; + CT* current_y_values_ptr = current_cache_element_ptr + embedding_size / 2; + + float* current_rotation_coeffts_cos_ptr = current_rotation_coeffts_ptr; + float* current_rotation_coeffts_sin_ptr = current_rotation_coeffts_ptr + embedding_size / 2; + + for (size_t iter_idx = 0; iter_idx < num_iterations; iter_idx++, + current_x_values_ptr += vec_len_in_f32_elts, + current_y_values_ptr += vec_len_in_f32_elts, + current_rotation_coeffts_cos_ptr += vec_len_in_f32_elts, + current_rotation_coeffts_sin_ptr += vec_len_in_f32_elts) { +# if defined(HAVE_AVX512F) + 
rotate_kv_cache_chunk_avx512(current_x_values_ptr, + current_y_values_ptr, + current_rotation_coeffts_cos_ptr, + current_rotation_coeffts_sin_ptr, + num_processed_elements_per_iteration / 2, + is_tail); +# else // HAVE_AVX2 + rotate_kv_cache_chunk_avx2(current_x_values_ptr, + current_y_values_ptr, + current_rotation_coeffts_cos_ptr, + current_rotation_coeffts_sin_ptr, + num_processed_elements_per_iteration / 2, + is_tail); +# endif // defined(HAVE_AVX512F) + } + } + } +#endif // !defined(HAVE_AVX512F) && !defined(HAVE_AVX2F) +} + +template +inline static void rotate_kv_cache_block_ref(CT* cache_block_ptr, + float* block_rotation_coefficients_ptr, + size_t num_heads, + size_t block_size, + size_t embedding_size) { + for (size_t head_idx = 0; head_idx < num_heads; head_idx++) { + for (size_t tok_idx = 0; tok_idx < block_size; tok_idx++) { + size_t token_offset = embedding_size * tok_idx; + CT* token_embedding_data_start_in_cache = + cache_block_ptr + head_idx * embedding_size * block_size + embedding_size * tok_idx; + float* token_data_start_in_rotation_coefficients = block_rotation_coefficients_ptr + token_offset; + for (size_t embedding_pair_idx = 0; embedding_pair_idx < embedding_size / 2; embedding_pair_idx++) { + // NB: below is the llama-style rotation (x-like values are in the first half of the embedding vector, + // y-like values are in the second half), which is different from the original RoFormer style (x- and y- + // values are interleaved), but still preserves the relative positional encoding property + CT* cache_value_0_ptr = token_embedding_data_start_in_cache + embedding_pair_idx; + CT* cache_value_1_ptr = cache_value_0_ptr + (embedding_size / 2); + + float rotation_value_cos = token_data_start_in_rotation_coefficients[embedding_pair_idx]; + float rotation_value_sin = + token_data_start_in_rotation_coefficients[embedding_pair_idx + (embedding_size / 2)]; + + CT cache_value_0 = *cache_value_0_ptr; + CT cache_value_1 = *cache_value_1_ptr; + + *cache_value_0_ptr = cache_value_0 * rotation_value_cos - cache_value_1 * rotation_value_sin; + *cache_value_1_ptr = cache_value_0 * rotation_value_sin + cache_value_1 * rotation_value_cos; + } + } + } +} + +template +inline static void rotate_kv_cache_block(CT* cache_block_ptr, + float* block_rotation_coefficients_ptr, + size_t num_heads, + size_t block_size, + size_t embedding_size) { +#if defined(HAVE_AVX512F) || defined(HAVE_AVX2) + rotate_kv_cache_block_opt(cache_block_ptr, block_rotation_coefficients_ptr, num_heads, block_size, embedding_size); +#else + rotate_kv_cache_block_ref(cache_block_ptr, block_rotation_coefficients_ptr, num_heads, block_size, embedding_size); +#endif // defined(HAVE_AVX512F) || defined(HAVE_AVX2) +} + +template <> +inline void rotate_kv_cache_block(uint8_t* cache_block_ptr, + float* block_rotation_coefficients_ptr, + size_t num_heads, + size_t block_size, + size_t embedding_size) { + OPENVINO_THROW("cache rotation is not implemented for INT8"); +} diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp index cb1cd24f840bfd..8b17b3ba8fb544 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp @@ -151,19 +151,30 @@ inline void mm512_uni_storeu_tail_ps(ov::float16* addr, __m512 v, size_t count) #endif #ifdef HAVE_AVX2 +inline __m128i get_8bit_tail_mask_for_16bit_elts(size_t num_16bit_tail_elts) { + // num_tail_elts may take from 0 to 8 + 
static int8_t masks[9][16] = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}}; + return _mm_loadu_si128(reinterpret_cast<__m128i*>(masks[num_16bit_tail_elts])); +} inline __m256i get_mask(int N7) { - static __m256i mask[] = { - _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, 0), - _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1), - _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1), - _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1), - _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1), - _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1), - _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1), - _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1), - _mm256_set_epi32(-1, -1, -1, -1, -1, -1, -1, -1), - }; - return _mm256_loadu_si256(&mask[N7]); + static int32_t masks[9][8] = {{0, 0, 0, 0, 0, 0, 0, 0}, + {-1, 0, 0, 0, 0, 0, 0, 0}, + {-1, -1, 0, 0, 0, 0, 0, 0}, + {-1, -1, -1, 0, 0, 0, 0, 0}, + {-1, -1, -1, -1, 0, 0, 0, 0}, + {-1, -1, -1, -1, -1, 0, 0, 0}, + {-1, -1, -1, -1, -1, -1, 0, 0}, + {-1, -1, -1, -1, -1, -1, -1, 0}, + {-1, -1, -1, -1, -1, -1, -1, -1}}; + return _mm256_loadu_si256(reinterpret_cast<__m256i*>(masks[N7])); } // load addr to __m256 reg @@ -207,7 +218,7 @@ inline void mm256_uni_storeu_ps(float* a, __m256 v) { _mm256_storeu_ps(a, v); } -inline void mm256_uni_storeu_ps(ov::bfloat16* addr, __m256 xps) { +inline __m128i __convert_avx2_packed_float_to_packed_ov_bfloat16(__m256 xps) { __m256i xpi32 = _mm256_castps_si256(xps); __m256i nan = _mm256_set1_epi32(0xffff); __m256i mask = _mm256_castps_si256(_mm256_cmp_ps(xps, xps, _CMP_ORD_Q)); @@ -220,6 +231,11 @@ inline void mm256_uni_storeu_ps(ov::bfloat16* addr, __m256 xps) { x = _mm256_packus_epi32(x, x); x = _mm256_permute4x64_epi64(x, 0xd8); __m128i bf16_o = _mm256_extractf128_si256(x, 0); + return bf16_o; +} + +inline void mm256_uni_storeu_ps(ov::bfloat16* addr, __m256 xps) { + __m128i bf16_o = __convert_avx2_packed_float_to_packed_ov_bfloat16(xps); _mm_storeu_si128(reinterpret_cast<__m128i*>(addr), bf16_o); } @@ -230,10 +246,22 @@ inline void mm256_uni_storeu_ps(ov::float16* a, __m256 v) { // store __m256 to addr inline void mm256_uni_storeu_tail_ps(float* addr, __m256 v, size_t count) { - const auto mask = get_mask(count); + auto mask = get_mask(count); return _mm256_maskstore_ps(addr, mask, v); } +inline void mm256_uni_storeu_tail_ps(ov::float16* addr, __m256 v, size_t count) { + auto mask = get_8bit_tail_mask_for_16bit_elts(count); + __m128i vec_f16 = _mm256_cvtps_ph(v, 0); + return _mm_maskmoveu_si128(vec_f16, mask, reinterpret_cast(addr)); +} + +inline void mm256_uni_storeu_tail_ps(ov::bfloat16* addr, __m256 v, size_t count) { + auto mask = get_8bit_tail_mask_for_16bit_elts(count); + __m128i bf16_o = __convert_avx2_packed_float_to_packed_ov_bfloat16(v); + return _mm_maskmoveu_si128(bf16_o, mask, reinterpret_cast(addr)); +} + inline void hsum(__m256& x) { __m256 y; // x: 0 1 2 3 4 5 6 7 y = _mm256_permute_ps(x, 0x39); // y: 1 2 3 0 5 6 7 4 diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp index 
e67c0312bf67cc..ce95d825d44f50 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.cpp @@ -16,6 +16,7 @@ #include "attn_memcpy.hpp" #include "attn_quant.hpp" #include "attn_quant_kernel.hpp" +#include "cache_rotation.hpp" #include "common.hpp" #include "executor_pa.hpp" #include "executor_pa_common.hpp" @@ -1150,6 +1151,66 @@ static void pack_32NxK(TDST* dst, OPENVINO_THROW("pack_32NxK: should not be called."); } +template +void fill_rotation_coefficients_from_lut(T* rotation_coefficients_block_data, + const int32_t* rotation_deltas_block_data, + size_t rotation_deltas_token_stride, + const T* rotation_trig_lut, + size_t block_size, + size_t embedding_size) { + size_t dst_offset = 0; + for (size_t tok_idx = 0; tok_idx < block_size; tok_idx++) { + size_t gather_idx = *(rotation_deltas_block_data + rotation_deltas_token_stride * tok_idx); + size_t src_offset = gather_idx * embedding_size; + std::memcpy(rotation_coefficients_block_data + dst_offset, + rotation_trig_lut + src_offset, + embedding_size * sizeof(T)); + dst_offset += embedding_size; + } +} + +template +void rotate_kv_cache(PlainTensor& key_cache, + const PlainTensor& rotated_block_indices, + const PlainTensor& rotation_deltas, + const PlainTensor& rotation_trig_lut, + PlainTensor& rotation_coefficients_scratch) { + size_t num_blocks_in_total = key_cache.size(0); + size_t num_heads = key_cache.size(1); // H; + size_t block_size = key_cache.size(2); + size_t embedding_size = key_cache.size(3); // S; + + size_t num_rotated_blocks = rotated_block_indices.size(0); + int32_t* rotated_block_indices_data = rotated_block_indices.ptr(); + float* rotation_trig_lut_data = rotation_trig_lut.ptr(); + + size_t rotation_deltas_token_stride = 0; + size_t rotation_deltas_block_stride = 1; + + bool is_per_token = (rotation_deltas.shape()[1] == block_size); + if (is_per_token) { + rotation_deltas_token_stride = 1; + rotation_deltas_block_stride = block_size; + } + + for (size_t i = 0; i < num_rotated_blocks; i++) { + size_t rotated_block_index = *(rotated_block_indices_data + i); + OPENVINO_ASSERT(rotated_block_index < num_blocks_in_total); + + int32_t* rotation_deltas_block_data = rotation_deltas.ptr() + i * rotation_deltas_block_stride; + + float* rotation_coefficient_block_data = rotation_coefficients_scratch.ptr(); + fill_rotation_coefficients_from_lut(rotation_coefficient_block_data, + rotation_deltas_block_data, + rotation_deltas_token_stride, + rotation_trig_lut_data, + block_size, + embedding_size); + KVCACHE_TYPE* cache_block_ptr = key_cache.ptr(rotated_block_index); + rotate_kv_cache_block(cache_block_ptr, rotation_coefficient_block_data, num_heads, block_size, embedding_size); + } +} + template struct MHAHelper { // initialize once @@ -1189,6 +1250,8 @@ struct MHAHelper { PlainTensor _score_offsets_aligned; PlainTensor _score_offsets; + PlainTensor _block_rotation_coefficient_scratch; + MHAHelper() { _weight.resize({size_t{1}, size_t{1}, size_t{1}, size_t{1}}); } @@ -1208,7 +1271,8 @@ struct MHAHelper { size_t sliding_window, float d_scale, size_t kv_len, - bool init_alibi_lookup) { + bool init_alibi_lookup, + bool init_rotation_coefficient_scratch) { // query shape: [B, H, L, S] // present_key shape: [block, H, 32, S] // Q*K': [M1, S] * [M2, S]' @@ -1306,6 +1370,10 @@ struct MHAHelper { for (size_t i = 0; i < _alibi_lookup.m_dims[0]; i++) _alibi_lookup.ptr()[i] = -static_cast((_alibi_lookup.m_dims[0] - 1 - i)); } + + if 
(init_rotation_coefficient_scratch) { + _block_rotation_coefficient_scratch.resize({_block_size, S}); + } } void init_reorder_buffers(size_t batch, size_t kv_len_in_blocks) { @@ -1584,12 +1652,13 @@ struct MHAHelper { // batch tokens. It will assume NO mixture execution of first and second token. all tensors such as query... have // batch dimension which is DIFFERENT from above // query: [B, H, L, S] - // present_*: [block_number, H, 32, S] + // key_cache: [block_number, H, _block_size, S] + // value_cache: [block_number, H, _block_size, Sv] // output_emb: [B, L, H * S] // 3 loops along batch, head, kv cache length dimensions void exec_loop_bhl(const PlainTensor& query, - const PlainTensor& present_key, - const PlainTensor& present_value, + PlainTensor& key_cache, + PlainTensor& value_cache, const PlainTensor& output_emb, const PlainTensor& output_score, size_t max_context_len, @@ -1647,7 +1716,7 @@ struct MHAHelper { for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { (*_gemv)(query.ptr(b, h, pq), - present_key.ptr(block_number, hk), + key_cache.ptr(block_number, hk), _weight_bhl.ptr(b, h, pq) + pk); } } @@ -1656,7 +1725,7 @@ struct MHAHelper { for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { dot_product_block(query.ptr(b, h, pq), - present_key.ptr(block_number, hk), + key_cache.ptr(block_number, hk), _weight_bhl.ptr(b, h, pq) + pk, _S, std::min(_block_size, context_len - pk), @@ -1729,12 +1798,11 @@ struct MHAHelper { auto block_number = block_indices.ptr()[block_indices_begins.ptr()[b] + pv_in_blocks]; for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { - auto sub_byte_multiplier = get_sub_byte_multiplier(present_value.get_precision()); - size_t v_stride = - (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) * - present_value.get_precision().size() / sub_byte_multiplier; + auto sub_byte_multiplier = get_sub_byte_multiplier(value_cache.get_precision()); + size_t v_stride = (block_number * value_cache.m_strides[0] + hk * value_cache.m_strides[1]) * + value_cache.get_precision().size() / sub_byte_multiplier; auto* v_ptr = reinterpret_cast::value_type*>( - present_value.m_ptr.get() + v_stride); + value_cache.m_ptr.get() + v_stride); attn_acc_value_block::value_type, VALUE_PREC>( _output_bhl.ptr(ithr, b, pq, h), _weight_bhl.ptr(b, h, pq) + pv, @@ -1868,7 +1936,7 @@ struct MHA { // one loop to handle first and second tokens void exec_loop_mixed(const PlainTensor& q, - const PlainTensor& k_cache, + PlainTensor& k_cache, const PlainTensor& v_cache, const PlainTensor& output_emb, const PlainTensor& output_score, @@ -2117,6 +2185,9 @@ struct AttentionExecutor : public PagedAttentionExecutor { size_t& sliding_window, PlainTensor& alibi_slopes, size_t& max_context_len, + PlainTensor& rotated_block_indices, + PlainTensor& rotation_deltas, + PlainTensor& rotation_trig_lut, PlainTensor& output_emb, PlainTensor& output_score) { q.reset(inputs[ID_Q]); // [B_token, H * S] @@ -2133,6 +2204,19 @@ struct AttentionExecutor : public PagedAttentionExecutor { if (!inputs[ID_ALIBI_SLOPES]->getShape().hasZeroDims()) alibi_slopes.reset(inputs[ID_ALIBI_SLOPES]); max_context_len = static_cast(*inputs[ID_MAX_CONTEXT_LEN]->getDataAs()); + + size_t inputs_size = inputs.size(); + if (inputs_size > ID_ROTATED_BLOCK_INDICES) { + OPENVINO_ASSERT(inputs_size >= ID_ROTATION_TRIG_LUT); + if (!inputs[ID_ROTATED_BLOCK_INDICES]->getShape().hasZeroDims()) + 
rotated_block_indices.reset(inputs[ID_ROTATED_BLOCK_INDICES]); // [num_blocks] + if (!inputs[ID_ROTATION_DELTAS]->getShape().hasZeroDims()) + rotation_deltas.reset(inputs[ID_ROTATION_DELTAS]); // [num_blocks, block_size (32) || 1] + if (!inputs[ID_ROTATION_TRIG_LUT]->getShape().hasZeroDims()) + rotation_trig_lut.reset( + inputs[ID_ROTATION_TRIG_LUT]); // [max_context_len * embedding_size], row-major layout + } + output_emb.reset(outputs[0]); if (outputs.size() == 2) output_score.reset(outputs[1]); @@ -2189,13 +2273,34 @@ struct AttentionExecutor : public PagedAttentionExecutor { if (alibi_slopes) { alibi_slopes.assert_dims({H}); } + + bool init_rotation_coefficient_scratch = false; + if (rotated_block_indices) { + // Only K entries are needed to be rotated, since position is encoded at the Q^T @ (effective_RoPE_matrix) @ + // K matrix multiplication + rotation_deltas.assert_dims({rotated_block_indices.size(0), 0}, /* special_zero = */ true); + OPENVINO_ASSERT(rotation_deltas.shape()[1] == 1 || + rotation_deltas.shape()[1] == block_size); // per-block or per-token granularity + rotation_trig_lut.assert_dims({0, S}, /* special_zero = */ true); + init_rotation_coefficient_scratch = true; + } output_emb.assert_dims({B_token, H * SV}); output_emb = output_emb.reshape({B_token, 1, H * SV}); // TODO: enable block_size to be multiple of 32 OPENVINO_ASSERT(block_size == 32, "CPU: block size must be 32, current: ", block_size); - _helper.init(H, S, SV, Hk, h_each_group_len, block_size, sliding_window, scale, max_context_len, alibi_slopes); + _helper.init(H, + S, + SV, + Hk, + h_each_group_len, + block_size, + sliding_window, + scale, + max_context_len, + alibi_slopes, + init_rotation_coefficient_scratch); } void concat_pastkv(const PlainTensor& k, @@ -2244,6 +2349,10 @@ struct AttentionExecutor : public PagedAttentionExecutor { size_t sliding_window; PlainTensor alibi_slopes; size_t max_context_len; + PlainTensor rotated_block_indices; + PlainTensor rotation_deltas; + PlainTensor rotation_trig_lut; + PlainTensor output_emb; PlainTensor output_score; @@ -2262,8 +2371,20 @@ struct AttentionExecutor : public PagedAttentionExecutor { sliding_window, alibi_slopes, max_context_len, + rotated_block_indices, + rotation_deltas, + rotation_trig_lut, output_emb, output_score); + + if (rotated_block_indices) { + rotate_kv_cache(k_cache, + rotated_block_indices, + rotation_deltas, + rotation_trig_lut, + _helper._block_rotation_coefficient_scratch); + } + concat_pastkv(k, v, k_cache, v_cache, past_lens, subsequence_begins, block_indices, block_indices_begins); _kernel(q, @@ -2367,4 +2488,4 @@ std::shared_ptr make_pa_executor(ov::element::Type data_ } // namespace XARCH } // namespace Cpu } // namespace Extensions -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp index 81c54c84d9453a..66911d4f4e7b1f 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp @@ -21,19 +21,22 @@ namespace Cpu { struct PagedAttentionExecutor { // PagedAttention input index - static const size_t ID_Q = 0; // [B_token, H * S], float - static const size_t ID_K = 1; // [B_token, Hk * S], float - static const size_t ID_V = 2; // [B_token, Hk * S], float - static const size_t ID_KCACHE = 3; // [block_number, H, block_size, S], float - static const size_t 
ID_VCACHE = 4; // [block_number, H, block_size, S], float - static const size_t ID_PAST_LENS = 5; // [B_seq] - static const size_t ID_SUBSEQUENCE_BEGINS = 6; // [B_seq+1] - static const size_t ID_BLOCK_INDICES = 7; // [num_blocks] - static const size_t ID_BLOCK_INDICES_BEGINS = 8; // [B_seq+1] - static const size_t ID_SCALE = 9; // [], float - static const size_t ID_SLIDING_WINDOW = 10; // [] - static const size_t ID_ALIBI_SLOPES = 11; // [H|0], float - static const size_t ID_MAX_CONTEXT_LEN = 12; // [] + static const size_t ID_Q = 0; // [B_token, H * S], float + static const size_t ID_K = 1; // [B_token, Hk * S], float + static const size_t ID_V = 2; // [B_token, Hk * S], float + static const size_t ID_KCACHE = 3; // [block_number, H, block_size, S], float + static const size_t ID_VCACHE = 4; // [block_number, H, block_size, S], float + static const size_t ID_PAST_LENS = 5; // [B_seq] + static const size_t ID_SUBSEQUENCE_BEGINS = 6; // [B_seq+1] + static const size_t ID_BLOCK_INDICES = 7; // [num_blocks] + static const size_t ID_BLOCK_INDICES_BEGINS = 8; // [B_seq+1] + static const size_t ID_SCALE = 9; // [], float + static const size_t ID_SLIDING_WINDOW = 10; // [] + static const size_t ID_ALIBI_SLOPES = 11; // [H|0], float + static const size_t ID_MAX_CONTEXT_LEN = 12; // [] + static const size_t ID_ROTATED_BLOCK_INDICES = 13; // [num_rotated_blocks || 0], int32 + static const size_t ID_ROTATION_DELTAS = 14; // [num_rotated_blocks * block_size || 0], int32 + static const size_t ID_ROTATION_TRIG_LUT = 15; // [max_context_length * S || 0], f32 virtual void execute(const std::vector& inputs, const std::vector outputs) = 0; virtual ~PagedAttentionExecutor() = default; @@ -107,4 +110,4 @@ class JitMatMulVecAMX : public dnnl::impl::cpu::x64::jit_generator { } // namespace Cpu } // namespace Extensions -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp index 54aa80e9dff7c0..b1632c34ff6fa2 100644 --- a/src/plugins/intel_cpu/src/nodes/paged_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/paged_attn.cpp @@ -82,7 +82,8 @@ void PagedAttention::initSupportedPrimitiveDescriptors() { creatorsMap.at(LayoutType::ncsp) ->createSharedDesc(rtPrecision, getInputShapeAtPort(PagedAttentionExecutor::ID_V))); - OPENVINO_ASSERT(orgInputNumber == 13, "The input number of PagedAttention should be 13."); + OPENVINO_ASSERT(orgInputNumber == 13 || orgInputNumber == 16, + "The input number of PagedAttention should be 13 or 16."); // kvcache, float, [] auto past_key_input_mem_precision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_KCACHE); auto past_value_input_mem_precision = getOriginalInputPrecisionAtPort(PagedAttentionExecutor::ID_VCACHE); @@ -130,6 +131,23 @@ void PagedAttention::initSupportedPrimitiveDescriptors() { config.outConfs[1].setMemDesc( creatorsMap.at(LayoutType::ncsp)->createSharedDesc(ov::element::f32, getOutputShapeAtPort(1))); + if (orgInputNumber == 16) { + // rotated_block_indices, int, [num_rotated_blocks || 0] + config.inConfs[PagedAttentionExecutor::ID_ROTATED_BLOCK_INDICES].setMemDesc( + creatorsMap.at(LayoutType::ncsp) + ->createSharedDesc(ov::element::i32, + getInputShapeAtPort(PagedAttentionExecutor::ID_ROTATED_BLOCK_INDICES))); + // rotation_deltas, int, [num_rotated_blocks, block_size || 1] || [0] + config.inConfs[PagedAttentionExecutor::ID_ROTATION_DELTAS].setMemDesc( + creatorsMap.at(LayoutType::ncsp) + ->createSharedDesc(ov::element::i32, 
getInputShapeAtPort(PagedAttentionExecutor::ID_ROTATION_DELTAS))); + // rotation_trig_lut, float, [max_context_len, embedding_size (aka S) || 0] + config.inConfs[PagedAttentionExecutor::ID_ROTATION_TRIG_LUT].setMemDesc( + creatorsMap.at(LayoutType::ncsp) + ->createSharedDesc(ov::element::f32, + getInputShapeAtPort(PagedAttentionExecutor::ID_ROTATION_TRIG_LUT))); + } + supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::ref_any); } diff --git a/src/plugins/intel_cpu/tests/unit/CMakeLists.txt b/src/plugins/intel_cpu/tests/unit/CMakeLists.txt index 63441b504735b0..81645f4fc87553 100644 --- a/src/plugins/intel_cpu/tests/unit/CMakeLists.txt +++ b/src/plugins/intel_cpu/tests/unit/CMakeLists.txt @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 # +add_subdirectory(vectorized) + set(TARGET_NAME ov_cpu_unit_tests) if(BUILD_SHARED_LIBS) @@ -52,6 +54,8 @@ ov_add_test_target( $/include EXCLUDED_SOURCE_PATHS ${EXCLUDED_SOURCE_PATHS_FOR_UNIT_TEST} + ${CMAKE_CURRENT_SOURCE_DIR}/vectorized + OBJECT_FILES ${OBJ_LIB} LINK_LIBRARIES @@ -78,6 +82,7 @@ if (ENABLE_SNIPPETS_LIBXSMM_TPP) target_include_directories(${TARGET_NAME} SYSTEM PRIVATE $) endif() + # LTO set_target_properties(${TARGET_NAME} PROPERTIES INTERPROCEDURAL_OPTIMIZATION_RELEASE ${ENABLE_LTO}) diff --git a/src/plugins/intel_cpu/tests/unit/vectorized/CMakeLists.txt b/src/plugins/intel_cpu/tests/unit/vectorized/CMakeLists.txt new file mode 100644 index 00000000000000..7e64c24f604a85 --- /dev/null +++ b/src/plugins/intel_cpu/tests/unit/vectorized/CMakeLists.txt @@ -0,0 +1,89 @@ +# Copyright (C) 2018-2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# + +set(TARGET_NAME ov_cpu_unit_tests_vectorized) + +if(BUILD_SHARED_LIBS) + set (OBJ_LIB $) +endif() + +if(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + ov_add_compiler_flags(/wd5051) +endif() + +if(NOT X86_64) + list(APPEND EXCLUDED_SOURCE_PATHS_FOR_UNIT_TEST + ${CMAKE_CURRENT_SOURCE_DIR}/paged_attn_cache_rotation.cpp) +else() + list(APPEND EXCLUDED_SOURCE_PATHS_FOR_UNIT_TEST + ${CMAKE_CURRENT_SOURCE_DIR}/stub.cpp) +endif() + +if (ENABLE_MLAS_FOR_CPU) + set(MLAS_LIBRARY "mlas") +endif() + +if (ENABLE_SHL_FOR_CPU) + set(SHL_LIBRARY "shl") +endif() + +ov_add_test_target( + NAME ${TARGET_NAME} + ROOT ${CMAKE_CURRENT_SOURCE_DIR} + INCLUDES + PUBLIC + $/src + $/src/nodes + $ + PRIVATE + $/include + EXCLUDED_SOURCE_PATHS + ${EXCLUDED_SOURCE_PATHS_FOR_UNIT_TEST} + OBJECT_FILES + ${OBJ_LIB} + LINK_LIBRARIES + gtest + gtest_main + dnnl + gmock + openvino_runtime_s + unit_test_utils + ov_snippets_models + snippets_test_utils + ${MLAS_LIBRARY} + ${SHL_LIBRARY} + ADD_CPPLINT + LABELS + OV UNIT CPU +) + + +if (ENABLE_SNIPPETS_LIBXSMM_TPP) + add_definitions(-DSNIPPETS_LIBXSMM_TPP -DLIBXSMM_DEFAULT_CONFIG) + target_compile_definitions(xsmm PRIVATE __BLAS=0) + target_link_libraries(${TARGET_NAME} PRIVATE xsmm) + target_include_directories(${TARGET_NAME} SYSTEM PRIVATE $) +endif() + +if (X86_64) + ov_avx2_optimization_flags(avx2_flags) + ov_avx512_optimization_flags(avx512_flags) + + target_compile_options(${TARGET_NAME} PRIVATE "${avx2_flags};${avx512_flags}") + target_compile_definitions(${TARGET_NAME} PRIVATE HAVE_AVX2 HAVE_AVX512F) +endif() + + +if (WIN32) + # Prevents defining min/max as macros + target_compile_definitions(${TARGET_NAME} PRIVATE NOMINMAX) +endif() + +target_include_directories(${TARGET_NAME} SYSTEM PRIVATE + $) + +target_include_directories(${TARGET_NAME} SYSTEM PRIVATE + $/src/common + $/src/cpu + $/include) diff --git 
a/src/plugins/intel_cpu/tests/unit/vectorized/paged_attn_cache_rotation.cpp b/src/plugins/intel_cpu/tests/unit/vectorized/paged_attn_cache_rotation.cpp new file mode 100644 index 00000000000000..870c5c576a73e1 --- /dev/null +++ b/src/plugins/intel_cpu/tests/unit/vectorized/paged_attn_cache_rotation.cpp @@ -0,0 +1,509 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include + +#include +#include +#include + +// the includes in the block below are necessary in order for the common.hpp header to be +// instantiated correctly +#include +#if defined(HAVE_AVX2) || defined(HAVE_AVX512F) +# include +#endif +#include "kernels/scaled_attn/common.hpp" +#include "nodes/kernels/scaled_attn/cache_rotation.hpp" +#include "perf_count.h" +#include "utils/plain_tensor.hpp" + +using namespace ov::intel_cpu; + +template +using Rank2Matrix = std::vector>; + +template +using Rank3Matrix = std::vector>>; + +// Expected layout: [block_size, embedding_size] +template +std::vector get_block_memory(size_t block_size, size_t embedding_size, const Rank2Matrix& init_values) { + auto mem = std::vector(block_size * embedding_size); + if (!init_values.empty()) { + assert(init_values.size() == block_size); + assert(init_values[0].size() == embedding_size); + for (size_t i = 0; i < block_size; i++) { + for (size_t j = 0; j < embedding_size; j++) { + mem[i * embedding_size + j] = init_values[i][j]; + } + } + } + return mem; +} + +// Expected layout: [num_heads, block_size, embedding_size] +template +std::vector get_block_memory(size_t num_heads, + size_t block_size, + size_t embedding_size, + const Rank3Matrix& init_values) { + auto mem = std::vector(num_heads * block_size * embedding_size); + if (!init_values.empty()) { + assert(init_values.size() == num_heads); + assert(init_values[0].size() == block_size); + assert(init_values[0][0].size() == embedding_size); + for (size_t i = 0; i < num_heads; i++) { + for (size_t j = 0; j < block_size; j++) { + for (size_t k = 0; k < embedding_size; k++) { + mem[i * embedding_size * block_size + j * embedding_size + k] = init_values[i][j][k]; + } + } + } + } + return mem; +} + +template +Rank3Matrix get_matrix_from_mem(std::vector mem_vec, + size_t num_heads, + size_t block_size, + size_t embedding_size) { + Rank3Matrix retval(num_heads); + for (size_t i = 0; i < num_heads; i++) { + retval[i].resize(block_size); + for (size_t j = 0; j < block_size; j++) { + retval[i][j].resize(embedding_size); + } + } + for (size_t i = 0; i < num_heads; i++) { + for (size_t j = 0; j < block_size; j++) { + for (size_t k = 0; k < embedding_size; k++) { + retval[i][j][k] = mem_vec[block_size * embedding_size * i + embedding_size * j + k]; + } + } + } + return retval; +} + +template +void compare_with_tolerance(const Rank3Matrix& test_data, const Rank3Matrix& ref_data, T abs_err) { + ASSERT_EQ(test_data.size(), ref_data.size()); + ASSERT_GT(test_data.size(), 0); + + ASSERT_EQ(test_data[0].size(), ref_data[0].size()); + ASSERT_GT(test_data[0].size(), 0); + + ASSERT_EQ(test_data[0][0].size(), ref_data[0][0].size()); + ASSERT_GT(test_data[0][0].size(), 0); + + for (size_t i = 0; i < test_data.size(); i++) { + for (size_t j = 0; j < test_data[0].size(); j++) { + for (size_t k = 0; k < test_data[0][0].size(); k++) { + T diff = test_data[i][j][k] - ref_data[i][j][k]; + if ((diff > abs_err) || (diff < -abs_err)) { + ADD_FAILURE() << std::setprecision(8) << "diff " << diff << " exceeding atol " << abs_err + << " at idx [" << i << ";" << j << ";" << 
k << "] --- test " << test_data[i][j][k] + << ", ref " << ref_data[i][j][k]; + } + } + } + } +} + +template +static T get_tolerance() { + return T{}; +} + +template <> +float get_tolerance() { + return 1e-6f; +} + +template <> +ov::float16 get_tolerance() { + return ov::float16{5e-3}; +} + +template <> +ov::bfloat16 get_tolerance() { + return ov::bfloat16{4e-2}; +} + +template +class CacheRotationKernelInputTypeParameterizedTest : public ::testing::Test { +public: + void SetUp() override { + Rank3Matrix values_before_rotation = { + { + {1.0f, 1.0f, 1.0f, 1.0f}, + {1.0f, 1.0f, 1.0f, 1.0f}, + {1.0f, 1.0f, 1.0f, 1.0f}, + {1.0f, 1.0f, 1.0f, 1.0f}, + }, + { + {-2.0f, -2.0f, -2.0f, -2.0f}, + {2.0f, 2.0f, 2.0f, 2.0f}, + {-1.0f, 2.0f, -3.0f, 4.0f}, + {2.0f, 2.0f, 2.0f, 2.0f}, + }, + }; + cache_mem = get_block_memory(num_heads, block_size, embedding_size, values_before_rotation); + + Rank2Matrix rotation_values = { + {0.5f, 0.70710678f, 0.86602540f, -0.70710678f}, + {0.86602540f, 1.0f, 0.5f, 0.0f}, + {-0.70710678f, 0.0f, 0.70710678f, 1.0f}, + {0.0f, 0.6f, -1.0f, -0.8f}, + }; + + rotation_coefficients_mem = get_block_memory(block_size, embedding_size, rotation_values); + } + size_t num_heads = 2; + size_t block_size = 4; + size_t embedding_size = 4; + std::vector cache_mem; + std::vector rotation_coefficients_mem; + Rank3Matrix ref_values_after_rotation = { + { + {-0.36602540f, 1.41421356f, 1.36602540f, 0.00000000f}, + {0.36602540f, 1.00000000f, 1.36602540f, 1.00000000f}, + {-1.41421356f, -1.00000000f, 0.00000000f, 1.00000000f}, + {1.00000000f, 1.40000000f, -1.00000000f, -0.20000000f}, + }, + { + {0.73205081f, -2.82842712f, -2.73205081f, 0.00000000f}, + {0.73205081f, 2.00000000f, 2.73205081f, 2.00000000f}, + {2.82842712f, -4.00000000f, 1.41421356f, 2.00000000f}, + {2.00000000f, 2.80000000f, -2.00000000f, -0.40000000f}, + }, + }; + + void test_block_opt_vs_ref(size_t num_heads, size_t embedding_size, size_t block_size) { + auto cache_block_mem_ref = get_block_memory(num_heads, block_size, embedding_size, Rank3Matrix{}); + auto rotation_coeffts_block_mem = get_block_memory(block_size, embedding_size, Rank2Matrix{}); + + std::mt19937 engine; + engine.seed(0); + std::uniform_real_distribution rng(-2.0, 2.0); + + auto generate_fn = [&]() { + return TypeParam(rng(engine)); + }; + + std::generate(cache_block_mem_ref.begin(), cache_block_mem_ref.end(), generate_fn); + // coeffts are now not strictly sine-cosine pairs, but it does not matter for the kernels + std::generate(rotation_coeffts_block_mem.begin(), + rotation_coeffts_block_mem.end(), + generate_fn); + + + + auto cache_block_mem_hw = cache_block_mem_ref; + + auto raw_mem_ptr_ref = cache_block_mem_ref.data(); + auto raw_rotation_coefficients_mem_ptr = rotation_coeffts_block_mem.data(); + auto raw_mem_ptr_hw = cache_block_mem_hw.data(); + + ov::intel_cpu::PerfCount counter; + { + ov::intel_cpu::PerfHelper helper(counter); + rotate_kv_cache_block_opt(raw_mem_ptr_hw, + raw_rotation_coefficients_mem_ptr, + num_heads, + block_size, + embedding_size); + } + + { + ov::intel_cpu::PerfHelper helper(counter); + rotate_kv_cache_block_ref(raw_mem_ptr_ref, + raw_rotation_coefficients_mem_ptr, + num_heads, + block_size, + embedding_size); + } + + auto ref_values_after_rotation = get_matrix_from_mem(cache_block_mem_ref, num_heads, block_size, embedding_size); + auto opt_values_after_rotation = get_matrix_from_mem(cache_block_mem_hw, num_heads, block_size, embedding_size); + compare_with_tolerance(opt_values_after_rotation, ref_values_after_rotation, 
get_tolerance()); + } +}; + +using OV_FP_TYPES = ::testing::Types; + +TYPED_TEST_SUITE_P(CacheRotationKernelInputTypeParameterizedTest); + +TYPED_TEST_P(CacheRotationKernelInputTypeParameterizedTest, RefBlockRotationGivesReferenceResults) { + auto raw_cache_mem_ptr = this->cache_mem.data(); + auto raw_rotation_coefficients_mem_ptr = this->rotation_coefficients_mem.data(); + + rotate_kv_cache_block_ref(raw_cache_mem_ptr, + raw_rotation_coefficients_mem_ptr, + this->num_heads, + this->block_size, + this->embedding_size); + + auto test_values_after_rotation = + get_matrix_from_mem(this->cache_mem, this->num_heads, this->block_size, this->embedding_size); + compare_with_tolerance(test_values_after_rotation, this->ref_values_after_rotation, get_tolerance()); +} + +enum class TargetInstructionSet { AVX2, AVX512 }; + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wsuggest-override" // false positive in gtest macro internals +#endif + +MATCHER_P3(IsNFirstValuesNear, ref_container, abs_err, n, "") { + if (ref_container.size() < n || arg.size() < n) + return false; + if (ref_container.size() != arg.size()) + return false; + + bool is_ok = true; + for (size_t i = 0; i < n; i++) { + if (!::testing::ExplainMatchResult(::testing::FloatNear(static_cast(arg[i]), abs_err), + static_cast(ref_container[i]), + result_listener)) { + *result_listener << " for element at idx " << i << '\n'; + is_ok = false; + } + } + return is_ok; +} + + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + +class CacheRotationKernelInstructionParameterizedTest + : public ::testing::TestWithParam> { +protected: + constexpr static size_t MAX_CHUNK_SIZE_IN_ELEMENTS = 16; + template + using MemChunk = std::array; + + template + void test_chunk_rotation_for_type() { + auto instruction_set = std::get<0>(GetParam()); + if (instruction_set == TargetInstructionSet::AVX512 && (!ov::with_cpu_x86_avx512f())) { + GTEST_SKIP() << "test executor must have AVX512 support"; + } + if (instruction_set == TargetInstructionSet::AVX2 && (!ov::with_cpu_x86_avx2())) { + GTEST_SKIP() << "test executor must have AVX2 support"; + } + auto num_elements_to_process = std::get<1>(GetParam()); + + MemChunk chunk_x = {-0.76777814f, + 0.97583583f, + -0.23619731f, + 0.19022397f, + 0.56691264f, + 0.64870757f, + 0.63334306f, + 1.97307894f, + 0.72495168f, + 1.22328697f, + -0.6005607f, + 0.17189973f, + -0.92268487f, + 0.40205632f, + 0.85996431f, + 1.70078315f}; + + MemChunk chunk_y = {1.68812157f, + -0.90722836f, + 0.58474063f, + -0.64561766f, + 0.62651501f, + 1.55990472f, + 0.41571189f, + 0.38366555f, + 0.09841767f, + 0.02218336f, + -0.07657361f, + 1.6062845f, + -1.08282323f, + -0.92034808f, + -1.48428038f, + 0.43501142f}; + + MemChunk chunk_cos = {-0.87461971f, + 0.95630476f, + 0.08715574f, + 0.8480481f, + -0.9612617f, + 0.27563736f, + 0.97437006f, + 0.66913061f, + -0.89100652f, + 0.98480775f, + -0.7313537f, + -0.2419219f, + 0.10452846f, + 0.70710678f, + -0.32556815f, + -0.2923717f}; + + MemChunk chunk_sin = {-0.48480962f, + -0.2923717f, + 0.9961947f, + 0.52991926f, + 0.27563736f, + -0.9612617f, + -0.22495105f, + 0.74314483f, + 0.4539905f, + -0.17364818f, + -0.68199836f, + -0.97029573f, + -0.9945219f, + -0.70710678f, + -0.94551858f, + 0.95630476f}; + + MemChunk ref_chunk_cos = chunk_cos; + MemChunk ref_chunk_sin = chunk_sin; + + MemChunk ref_chunk_x = {1.48993147f, + 0.66794854f, + -0.60310147f, + 0.50344431f, + -0.71764235f, + 1.6782847f, + 0.71062535f, + 1.03512844f, + -0.69061736f, + 1.20855459f, + 0.38699921f, + 
1.51698468f, + -1.17333824f, + -0.36648762f, + -1.68339166f, + -0.91326436f}; + + MemChunk ref_chunk_y = {-1.10423816f, + -1.15289358f, + -0.184335f, + -0.44671148f, + -0.44598258f, + -0.19360973f, + 0.26258603f, + 1.72300577f, + 0.24143039f, + -0.19057521f, + 0.46558381f, + -0.55538896f, + 0.80444446f, + -0.93508112f, + -0.32987781f, + 1.49928198f}; + + // unprocessed elements should remain untouched + std::copy(chunk_x.begin() + num_elements_to_process, + chunk_x.end(), + ref_chunk_x.begin() + num_elements_to_process); + std::copy(chunk_y.begin() + num_elements_to_process, + chunk_y.end(), + ref_chunk_y.begin() + num_elements_to_process); + + switch (instruction_set) { + using namespace ov::Extensions::Cpu::XARCH; + case TargetInstructionSet::AVX2: + rotate_kv_cache_chunk_avx2(chunk_x.data(), + chunk_y.data(), + chunk_cos.data(), + chunk_sin.data(), + num_elements_to_process, + /* is_tail = */ num_elements_to_process < vec_len_f32_avx2); + break; + case TargetInstructionSet::AVX512: + rotate_kv_cache_chunk_avx512(chunk_x.data(), + chunk_y.data(), + chunk_cos.data(), + chunk_sin.data(), + num_elements_to_process, + /* is_tail = */ num_elements_to_process < vec_len_f32_avx512); + break; + default: + FAIL() << "unknown target instruction set"; + } + + std::string type_name = ov::element::from().to_string(); + + EXPECT_THAT(chunk_x, IsNFirstValuesNear(ref_chunk_x, get_tolerance(), num_elements_to_process)) + << ", element type is: " << type_name; + EXPECT_THAT(chunk_y, IsNFirstValuesNear(ref_chunk_y, get_tolerance(), num_elements_to_process)) + << ", element type is: " << type_name; + + EXPECT_EQ(chunk_cos, ref_chunk_cos) << ", element type is: " << type_name; + EXPECT_EQ(chunk_sin, ref_chunk_sin) << ", element type is: " << type_name; + } +}; + +TEST_P(CacheRotationKernelInstructionParameterizedTest, OptChunkRotationGivesReferenceResults) { + test_chunk_rotation_for_type(); + test_chunk_rotation_for_type(); + test_chunk_rotation_for_type(); +} + +auto TEST_STRUCT_TO_NAME_FN = + [](const testing::TestParamInfo& info) { + size_t num_elts = std::get<1>(info.param); + switch (std::get<0>(info.param)) { + case TargetInstructionSet::AVX2: + return std::string("avx2-") + std::to_string(num_elts); + case TargetInstructionSet::AVX512: + return std::string("avx512-") + std::to_string(num_elts); + } + return std::string("unknown"); + }; + +INSTANTIATE_TEST_SUITE_P(AVX2, + CacheRotationKernelInstructionParameterizedTest, + ::testing::Combine(::testing::Values(TargetInstructionSet::AVX2), + ::testing::Range(size_t(0), + ov::Extensions::Cpu::XARCH::vec_len_f32_avx2 + 1)), + TEST_STRUCT_TO_NAME_FN); +INSTANTIATE_TEST_SUITE_P(AVX512, + CacheRotationKernelInstructionParameterizedTest, + ::testing::Combine(::testing::Values(TargetInstructionSet::AVX512), + ::testing::Range(size_t(0), + ov::Extensions::Cpu::XARCH::vec_len_f32_avx512 + 1)), + TEST_STRUCT_TO_NAME_FN); + +TYPED_TEST_P(CacheRotationKernelInputTypeParameterizedTest, OptBlockRotationGivesReferenceResults) { + auto raw_cache_mem_ptr = this->cache_mem.data(); + auto raw_rotation_coefficients_mem_ptr = this->rotation_coefficients_mem.data(); + + rotate_kv_cache_block_opt(raw_cache_mem_ptr, + raw_rotation_coefficients_mem_ptr, + this->num_heads, + this->block_size, + this->embedding_size); + + auto test_values_after_rotation = + get_matrix_from_mem(this->cache_mem, this->num_heads, this->block_size, this->embedding_size); + compare_with_tolerance(test_values_after_rotation, this->ref_values_after_rotation, get_tolerance()); +} + 
+TYPED_TEST_P(CacheRotationKernelInputTypeParameterizedTest, OptBlockRotationIsSimilarToRef) {
+    // short case
+    this->test_block_opt_vs_ref(/* num_heads = */ 4, /* embedding_size = */ 64, /* block_size = */ 2);
+
+    // long case
+    this->test_block_opt_vs_ref(256, 1024, 32);
+}
+
+REGISTER_TYPED_TEST_SUITE_P(CacheRotationKernelInputTypeParameterizedTest,
+                            RefBlockRotationGivesReferenceResults,
+                            OptBlockRotationGivesReferenceResults,
+                            OptBlockRotationIsSimilarToRef);
+INSTANTIATE_TYPED_TEST_SUITE_P(AllFPTypes, CacheRotationKernelInputTypeParameterizedTest, OV_FP_TYPES);
diff --git a/src/plugins/intel_cpu/tests/unit/vectorized/stub.cpp b/src/plugins/intel_cpu/tests/unit/vectorized/stub.cpp
new file mode 100644
index 00000000000000..2c6aba41d2231d
--- /dev/null
+++ b/src/plugins/intel_cpu/tests/unit/vectorized/stub.cpp
@@ -0,0 +1,12 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+TEST(StubTest, AlwaysPass) {
+    // Some target platforms for the vectorized tests do not have any cases right now.
+    // In order to make the build pass on these platforms, the build system will include this
+    // file as the only source for the ov_cpu_unit_tests_vectorized, and the test binary
+    // will always pass the run.
+}
diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp
index 2638f2ad60cf26..ad79e5178f21a8 100644
--- a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp
+++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp
@@ -21,7 +21,9 @@ struct paged_attention : public primitive_base<paged_attention> {
     paged_attention(const primitive_id& id, const std::vector<input_info>& inputs)
         : primitive_base(id, inputs) {
-        OPENVINO_ASSERT(inputs.size() == 13, "[GPU] Unexpected inputs number for PagedAttention primitive: ", inputs.size());
+        OPENVINO_ASSERT((inputs.size() == 13) || (inputs.size() == 15),
+                        "[GPU] Unexpected inputs number for PagedAttention primitive: ",
+                        inputs.size());
     }
 
     bool has_scores_output() const {
@@ -38,6 +40,7 @@ struct paged_attention : public primitive_base<paged_attention> {
         ob << heads_num;
         ob << kv_heads_num;
         ob << has_alibi;
+        ob << has_rotated_blocks;
     }
 
     void load(BinaryInputBuffer& ib) override {
@@ -46,6 +49,7 @@ struct paged_attention : public primitive_base<paged_attention> {
         ib >> heads_num;
         ib >> kv_heads_num;
         ib >> has_alibi;
+        ib >> has_rotated_blocks;
     }
 
     optional_value<float> scale_val{};
@@ -53,5 +57,6 @@ struct paged_attention : public primitive_base<paged_attention> {
     size_t heads_num = 0;
     size_t kv_heads_num = 0;
     bool has_alibi = false;
+    bool has_rotated_blocks = false;
 };
 } // namespace cldnn
diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
index 2bc377f2c1459a..15a1632a8a2b1f 100644
--- a/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
+++ b/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
@@ -204,6 +204,11 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
                 // dependency
                 args.inputs.push_back(instance.subsequence_begins_memory_ptr());
             }
+            if (desc->has_rotated_blocks) {
+                args.inputs.push_back(instance.rotated_block_indices_memory_ptr());
+                args.inputs.push_back(instance.rotation_deltas_memory_ptr());
+                args.inputs.push_back(instance.rotation_trig_lut_memory_ptr());
+            }
         } else if (kernel_idx == 4) {
             // Output scores calculation kernel
             args.inputs = { instance.past_lens_memory_ptr(),
diff --git a/src/plugins/intel_gpu/src/graph/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/paged_attention.cpp
index c761aaf63799cd..48ae46d83de34a 100644
--- a/src/plugins/intel_gpu/src/graph/paged_attention.cpp
+++ b/src/plugins/intel_gpu/src/graph/paged_attention.cpp
@@ -98,6 +98,7 @@ std::string paged_attention_inst::to_string(const paged_attention_node& node) {
     paged_attention_info.add("kv_heads_num", desc->kv_heads_num);
     paged_attention_info.add("scale", desc->scale_val.value_or(1.0f));
     paged_attention_info.add("has_alibi", desc->has_alibi);
+    paged_attention_info.add("has_rotated_blocks", desc->has_rotated_blocks);
 
     node_info->add("paged_attention primitive info", paged_attention_info);
     node_info->dump(primitive_description);
diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
index 7e960afa4b87d3..2d6598e0a654cc 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
+++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
@@ -42,6 +42,11 @@ KERNEL(pa_sdpa_opt)(
 #endif
 #if HAS_ALIBI
     const __global ALIBI_INPUT_TYPE* alibi_slopes,
+#endif
+#if HAS_ROTATED_BLOCKS
+    const __global INPUT8_TYPE* rotated_block_indices,
+    const __global INPUT9_TYPE* rotation_deltas,
+    const __global INPUT10_TYPE* rotation_trig_lut,
 #endif
     __global OUTPUT_TYPE* output,
 #if PAGED_ATTENTION_SCORES_OUTPUT
@@ -62,7 +67,10 @@ KERNEL(pa_sdpa_opt)(
     // past_lens: [sequences_num]
     // subsequence_begins: [sequences_num + 1]
     // block_indices: [used_blocks_num]
-    // block_indices: [sequences_num + 1]
+    // block_indices_begins: [sequences_num + 1]
+    // rotated_block_indices: [num_rotated_blocks ]
+    // rotation_deltas [num_rotated_blocks, 1 || PAGED_ATTENTION_BLOCK_SIZE ]
+    // rotation_trig_lut [MAX_CONTEXT_LEN, HEAD_SIZE]
     //
     // Output shapes:
     // output: [sequences_num, HEADS_NUM * HEAD_SIZE]
@@ -148,6 +156,10 @@ KERNEL(pa_sdpa_opt)(
     }
 #endif
 
+#ifdef HAS_ROTATED_BLOCKS
+    // TODO (vshampor): add cache block rotation at this spot
+#endif
+
     const uint blocks_num_per_partition = min(total_blocks_num - partition_idx * PAGED_ATTENTION_BLOCKS_PER_PARTITION, (uint)PAGED_ATTENTION_BLOCKS_PER_PARTITION);
 
     uint blocks_num = blocks_num_per_partition / SUBGROUPS_PER_WG;
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp
index 909a40d677f535..bac6ebd11fbe9b 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp
@@ -237,6 +237,9 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
         jit.AddConstant(MakeJitConstant("PAGED_ATTENTION_SCORES_OUTPUT", 1));
     }
 
+    if (params.conf.has_rotated_blocks)
+        jit.AddConstant(MakeJitConstant("HAS_ROTATED_BLOCKS", 1));
+
     if (kernel_idx == KernelsTypes::MULTI_TOKENS || kernel_idx == KernelsTypes::FINALIZATION_MULTI_TOKENS)
         jit.AddConstant(MakeJitConstant("MULTI_TOKENS_PROCESSING", 1));
 
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h
index 8fcc4a16692d6c..7b9519395d88ca 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h
@@ -100,6 +100,7 @@ struct sdpa_configuration {
     int64_t
paged_attention_max_len = 0; bool has_const_scale_val = false; float scale_val = 0.f; + bool has_rotated_blocks = false; }; //////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp index d82d3a66fed7f7..b56807d720b870 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp @@ -48,6 +48,7 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared const size_t scale_idx = 9; const size_t alibi_idx = 11; + const size_t rotated_block_indices_idx = 13; std::shared_ptr scale_const = std::dynamic_pointer_cast(op->get_input_node_shared_ptr(scale_idx)); if (scale_const) { @@ -62,6 +63,12 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared prim.has_alibi = ov::shape_size(alibi_const->get_output_shape(0)) > 0; prim.num_outputs = 1; + + std::shared_ptr rotated_block_indices_const = + std::dynamic_pointer_cast(op->get_input_node_shared_ptr(rotated_block_indices_idx)); + OPENVINO_ASSERT(rotated_block_indices_const != nullptr); + prim.has_rotated_blocks = ov::shape_size(rotated_block_indices_const->get_output_shape(0)) > 0; + if (op->get_output_size() > 1) { const auto scores_output_idx = 1; const auto& users = op->get_output_target_inputs(scores_output_idx); diff --git a/tests/model_hub_tests/transformation_tests/generate_ref_diffs.py b/tests/model_hub_tests/transformation_tests/generate_ref_diffs.py index 72051783fa7422..36d1c0e863635f 100644 --- a/tests/model_hub_tests/transformation_tests/generate_ref_diffs.py +++ b/tests/model_hub_tests/transformation_tests/generate_ref_diffs.py @@ -56,7 +56,7 @@ def get_models_list_type(file_name: str, cls: Union[Type[OVModelForCausalLM], Ty models.append((model_name, model_link, None, None, cls)) elif len(line_items) == 4: model_name, model_link, mark, reason = line_items - models.append((model_name, model_link, mark, reason)) + models.append((model_name, model_link, mark, reason, cls)) elif len(line_items) > 4: model_name, model_link, mark, reason, *other = line_items if not mark: @@ -106,7 +106,7 @@ def main(): # wrapping in try/catch block to continue printing models even if one has failed try: - paged_attention_transformation(ov_model, use_cache_eviction, use_cache_eviction) + paged_attention_transformation(ov_model, use_cache_eviction, use_cache_eviction, use_cache_eviction) except: print(f"Couldn't run SDPAToPA transformation on {model_id} and generate diffs.") continue @@ -117,10 +117,12 @@ def main(): after_map[op.get_type_name()] = after_map.get(op.get_type_name(), 0) + 1 print(f'\t"{model_id}" : {{', file=file) - for op in set(after_map.keys()) | set(before_map.keys()): + for op in sorted(set(after_map.keys()) | set(before_map.keys())): print(f'\t\t"{op}" : {after_map.get(op, 0) - before_map.get(op, 0)},', file=file) print('\t},', file=file) print('}', file=file) + print(f"output written to {OUTPUT_FILE}") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tests/model_hub_tests/transformation_tests/sdpa2pa_ref_diff.py b/tests/model_hub_tests/transformation_tests/sdpa2pa_ref_diff.py index 43ef49d9b5a226..aac6c3765aca3b 100644 --- a/tests/model_hub_tests/transformation_tests/sdpa2pa_ref_diff.py +++ b/tests/model_hub_tests/transformation_tests/sdpa2pa_ref_diff.py @@ -5,666 +5,665 @@ ref_diff_map = { 
"hf-internal-testing/tiny-random-LlamaForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-CohereForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-GPTJForCausalLM" : { + "Assign" : -10, "PagedAttentionExtension" : 5, - "ScaledDotProductAttention" : -5, "Parameter" : 13, "ReadValue" : -10, - "Assign" : -10, - }, - "hf-internal-testing/tiny-random-GPTNeoForCausalLM" : { - "PagedAttentionExtension" : 4, - "ScaledDotProductAttention" : -4, - "Parameter" : 11, - "ReadValue" : -8, - "Assign" : -8, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-GPTNeoXForCausalLM" : { + "Assign" : -10, "PagedAttentionExtension" : 5, - "ScaledDotProductAttention" : -5, "Parameter" : 13, "ReadValue" : -10, - "Assign" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-MistralForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-CodeGenForCausalLM" : { + "Assign" : -10, "PagedAttentionExtension" : 5, - "ScaledDotProductAttention" : -5, "Parameter" : 13, "ReadValue" : -10, - "Assign" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/Mixtral-tiny" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-GPTBigCodeForCausalLM" : { + "Assign" : -5, "PagedAttentionExtension" : 5, - "ScaledDotProductAttention" : -5, "Parameter" : 13, "ReadValue" : -5, - "Assign" : -5, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-Starcoder2ForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-BloomForCausalLM" : { + "Assign" : -10, "PagedAttentionExtension" : 5, - "ScaledDotProductAttention" : -5, "Parameter" : 14, "ReadValue" : -10, - "Assign" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-gpt2" : { + "Assign" : -10, "PagedAttentionExtension" : 5, - "ScaledDotProductAttention" : -5, "Parameter" : 13, "ReadValue" : -10, - "Assign" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-BlenderbotForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 8, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-PegasusForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 8, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-PhiForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-MptForCausalLM" : { + "Assign" : -10, "PagedAttentionExtension" : 5, - "ScaledDotProductAttention" 
: -5, "Parameter" : 14, "ReadValue" : -10, - "Assign" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-StableLmForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-PersimmonForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-FalconForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "hf-tiny-model-private/tiny-random-OPTForCausalLM" : { + "Assign" : -10, "PagedAttentionExtension" : 5, - "ScaledDotProductAttention" : -5, "Parameter" : 14, "ReadValue" : -10, - "Assign" : -10, + "ScaledDotProductAttention" : -5, }, "katuni4ka/tiny-random-xverse" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-baichuan2-13b" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-qwen" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-aquilachat" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-aquila2" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-qwen1.5-moe" : { + "Assign" : -8, "PagedAttentionExtension" : 4, - "ScaledDotProductAttention" : -4, "Parameter" : 11, "ReadValue" : -8, - "Assign" : -8, + "ScaledDotProductAttention" : -4, }, "katuni4ka/tiny-random-codegen2" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-olmo-hf" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-baichuan2" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-jais" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-internlm" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-internlm2" : { + "Assign" : -8, "PagedAttentionExtension" : 4, - "ScaledDotProductAttention" : -4, "Parameter" : 11, "ReadValue" : -8, + "ScaledDotProductAttention" : 
-4, + }, + "katuni4ka/tiny-random-minicpm" : { "Assign" : -8, + "PagedAttentionExtension" : 4, + "Parameter" : 11, + "ReadValue" : -8, + "ScaledDotProductAttention" : -4, }, - "katuni4ka/tiny-random-minicpm" : { - "ReadValue" : -8, - "ScaledDotProductAttention" : -4, - "Assign" : -8, - "PagedAttentionExtension" : 4, - "Parameter" : 11, - }, "katuni4ka/tiny-random-falcon-40b" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-dbrx" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "fxmarty/tiny-random-GemmaForCausalLM" : { + "Assign" : -2, "PagedAttentionExtension" : 1, - "ScaledDotProductAttention" : -1, "Parameter" : 5, "ReadValue" : -2, - "Assign" : -2, + "ScaledDotProductAttention" : -1, }, "fxmarty/tiny-dummy-qwen2" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "fxmarty/really-tiny-falcon-testing" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "Xenova/tiny-random-Phi3ForCausalLM" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "facebook/opt-125m" : { + "Assign" : -24, "PagedAttentionExtension" : 12, - "ScaledDotProductAttention" : -12, "Parameter" : 28, "ReadValue" : -24, - "Assign" : -24, + "ScaledDotProductAttention" : -12, }, "facebook/opt-350m" : { + "Assign" : -48, "PagedAttentionExtension" : 24, - "ScaledDotProductAttention" : -24, "Parameter" : 52, "ReadValue" : -48, - "Assign" : -48, + "ScaledDotProductAttention" : -24, }, "katuni4ka/tiny-random-chatglm2" : { + "Assign" : -4, "PagedAttentionExtension" : 2, - "ScaledDotProductAttention" : -2, "Parameter" : 7, "ReadValue" : -4, - "Assign" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-glm4" : { + "Assign" : -12, "PagedAttentionExtension" : 6, - "ScaledDotProductAttention" : -6, "Parameter" : 15, "ReadValue" : -12, - "Assign" : -12, + "ScaledDotProductAttention" : -6, }, "katuni4ka/tiny-random-llava-next" : { + "Assign" : -4, "PagedAttentionExtension" : 2, "Parameter" : 7, "ReadValue" : -4, "ScaledDotProductAttention" : -2, - "Assign" : -4, }, "katuni4ka/tiny-random-minicpmv-2_6" : { + "Assign" : -4, "PagedAttentionExtension" : 2, "Parameter" : 7, "ReadValue" : -4, "ScaledDotProductAttention" : -2, - "Assign" : -4, }, "katuni4ka/tiny-random-llava" : { "Assign" : -4, + "PagedAttentionExtension" : 2, "Parameter" : 7, "ReadValue" : -4, "ScaledDotProductAttention" : -2, - "PagedAttentionExtension" : 2, }, - # "katuni4ka/tiny-random-nanollava" : { # "Assign" : -4, + # "PagedAttentionExtension" : 2, # "Parameter" : 7, # "ReadValue" : -4, # "ScaledDotProductAttention" : -2, - # "PagedAttentionExtension" : 2, # }, + "hf-internal-testing/tiny-random-GPTNeoForCausalLM" : { + "ScaledDotProductAttention" : -4, + "ReadValue" : -8, + "PagedAttentionExtension" : 4, + "Parameter" : 11, + "Assign" : -8, + } } ref_diff_map_cache_eviction = { "hf-internal-testing/tiny-random-LlamaForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - 
"PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-CohereForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-GPTJForCausalLM" : { - "ScaledDotProductAttention" : -5, - "ReadValue" : -10, - "PagedAttentionExtension" : 5, - "Parameter" : 17, "Assign" : -10, - }, - "hf-internal-testing/tiny-random-GPTNeoForCausalLM" : { - "ScaledDotProductAttention" : -4, - "ReadValue" : -8, - "PagedAttentionExtension" : 4, - "Parameter" : 14, - "Assign" : -8, + "PagedAttentionExtension" : 5, + "Parameter" : 28, + "ReadValue" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-GPTNeoXForCausalLM" : { - "ScaledDotProductAttention" : -5, - "ReadValue" : -10, - "PagedAttentionExtension" : 5, - "Parameter" : 17, "Assign" : -10, + "PagedAttentionExtension" : 5, + "Parameter" : 28, + "ReadValue" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-MistralForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-CodeGenForCausalLM" : { - "ScaledDotProductAttention" : -5, - "ReadValue" : -10, - "PagedAttentionExtension" : 5, - "Parameter" : 17, "Assign" : -10, + "PagedAttentionExtension" : 5, + "Parameter" : 28, + "ReadValue" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/Mixtral-tiny" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-GPTBigCodeForCausalLM" : { - "ScaledDotProductAttention" : -5, - "ReadValue" : -5, - "PagedAttentionExtension" : 5, - "Parameter" : 17, "Assign" : -5, + "PagedAttentionExtension" : 5, + "Parameter" : 28, + "ReadValue" : -5, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-Starcoder2ForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-BloomForCausalLM" : { - "ScaledDotProductAttention" : -5, - "ReadValue" : -10, - "PagedAttentionExtension" : 5, - "Parameter" : 18, "Assign" : -10, + "PagedAttentionExtension" : 5, + "Parameter" : 29, + "ReadValue" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-gpt2" : { - "ScaledDotProductAttention" : -5, - "ReadValue" : -10, - "PagedAttentionExtension" : 5, - "Parameter" : 17, "Assign" : -10, + "PagedAttentionExtension" : 5, + "Parameter" : 28, + "ReadValue" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-BlenderbotForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 9, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 14, + "ReadValue" : -4, + 
"ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-PegasusForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 9, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 14, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-PhiForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-MptForCausalLM" : { - "ScaledDotProductAttention" : -5, - "ReadValue" : -10, - "PagedAttentionExtension" : 5, - "Parameter" : 18, "Assign" : -10, + "PagedAttentionExtension" : 5, + "Parameter" : 29, + "ReadValue" : -10, + "ScaledDotProductAttention" : -5, }, "hf-internal-testing/tiny-random-StableLmForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-PersimmonForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-internal-testing/tiny-random-FalconForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "hf-tiny-model-private/tiny-random-OPTForCausalLM" : { - "ScaledDotProductAttention" : -5, - "ReadValue" : -10, - "PagedAttentionExtension" : 5, - "Parameter" : 18, "Assign" : -10, + "PagedAttentionExtension" : 5, + "Parameter" : 29, + "ReadValue" : -10, + "ScaledDotProductAttention" : -5, }, "katuni4ka/tiny-random-xverse" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-baichuan2-13b" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-qwen" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-aquilachat" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-aquila2" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-qwen1.5-moe" : { - "ScaledDotProductAttention" : -4, - "ReadValue" : -8, - 
"PagedAttentionExtension" : 4, - "Parameter" : 14, "Assign" : -8, + "PagedAttentionExtension" : 4, + "Parameter" : 23, + "ReadValue" : -8, + "ScaledDotProductAttention" : -4, }, "katuni4ka/tiny-random-codegen2" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-olmo-hf" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-baichuan2" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-jais" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-internlm" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-internlm2" : { - "ScaledDotProductAttention" : -4, - "ReadValue" : -8, - "PagedAttentionExtension" : 4, - "Parameter" : 14, "Assign" : -8, + "PagedAttentionExtension" : 4, + "Parameter" : 23, + "ReadValue" : -8, + "ScaledDotProductAttention" : -4, }, "katuni4ka/tiny-random-minicpm" : { - "ScaledDotProductAttention" : -4, - "Parameter" : 14, + "Assign" : -8, "PagedAttentionExtension" : 4, + "Parameter" : 23, "ReadValue" : -8, - "Assign" : -8, + "ScaledDotProductAttention" : -4, }, "katuni4ka/tiny-random-falcon-40b" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-dbrx" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "fxmarty/tiny-random-GemmaForCausalLM" : { - "ScaledDotProductAttention" : -1, - "ReadValue" : -2, - "PagedAttentionExtension" : 1, - "Parameter" : 5, "Assign" : -2, + "PagedAttentionExtension" : 1, + "Parameter" : 8, + "ReadValue" : -2, + "ScaledDotProductAttention" : -1, }, "fxmarty/tiny-dummy-qwen2" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "fxmarty/really-tiny-falcon-testing" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "Xenova/tiny-random-Phi3ForCausalLM" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, 
"Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "facebook/opt-125m" : { - "ScaledDotProductAttention" : -12, - "ReadValue" : -24, - "PagedAttentionExtension" : 12, - "Parameter" : 39, "Assign" : -24, + "PagedAttentionExtension" : 12, + "Parameter" : 64, + "ReadValue" : -24, + "ScaledDotProductAttention" : -12, }, "facebook/opt-350m" : { - "ScaledDotProductAttention" : -24, - "ReadValue" : -48, - "PagedAttentionExtension" : 24, - "Parameter" : 75, "Assign" : -48, + "PagedAttentionExtension" : 24, + "Parameter" : 124, + "ReadValue" : -48, + "ScaledDotProductAttention" : -24, }, "katuni4ka/tiny-random-chatglm2" : { - "ScaledDotProductAttention" : -2, - "ReadValue" : -4, - "PagedAttentionExtension" : 2, - "Parameter" : 8, "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, + "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-glm4" : { - "ScaledDotProductAttention" : -6, - "ReadValue" : -12, - "PagedAttentionExtension" : 6, - "Parameter" : 20, "Assign" : -12, + "PagedAttentionExtension" : 6, + "Parameter" : 33, + "ReadValue" : -12, + "ScaledDotProductAttention" : -6, }, "katuni4ka/tiny-random-llava-next" : { - "Parameter" : 8, "Assign" : -4, - "ReadValue" : -4, "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-minicpmv-2_6" : { - "Parameter" : 8, "Assign" : -4, - "ReadValue" : -4, "PagedAttentionExtension" : 2, + "Parameter" : 13, + "ReadValue" : -4, "ScaledDotProductAttention" : -2, }, "katuni4ka/tiny-random-llava" : { + "Assign" : -4, + "PagedAttentionExtension" : 2, + "Parameter" : 13, "ReadValue" : -4, - "Parameter" : 8, "ScaledDotProductAttention" : -2, - "PagedAttentionExtension" : 2, - "Assign" : -4, }, - - # "katuni4ka/tiny-random-nanollava" : { + # "katuni4ka/tiny-random-nanollava" : { + # "Assign" : -4, + # "PagedAttentionExtension" : 2, + # "Parameter" : 13, # "ReadValue" : -4, - # "Parameter" : 8, # "ScaledDotProductAttention" : -2, - # "PagedAttentionExtension" : 2, - # "Assign" : -4, # }, + + "hf-internal-testing/tiny-random-GPTNeoForCausalLM" : { + "ScaledDotProductAttention" : -4, + "ReadValue" : -8, + "PagedAttentionExtension" : 4, + "Parameter" : 23, + "Assign" : -8, + } } diff --git a/tests/model_hub_tests/transformation_tests/test_pa_transformation.py b/tests/model_hub_tests/transformation_tests/test_pa_transformation.py index 2bc6726dff030f..fc6e8c1e65903f 100644 --- a/tests/model_hub_tests/transformation_tests/test_pa_transformation.py +++ b/tests/model_hub_tests/transformation_tests/test_pa_transformation.py @@ -17,13 +17,14 @@ def compare_diffs(ov_model: ov.Model, model_id: str, use_block_indices_inputs: bool, - use_score_outputs: bool): + use_score_outputs: bool, + allow_cache_rotation: bool): before_map = {} for op in ov_model.get_ordered_ops(): if op.get_type_name() in nodes_to_compare: before_map[op.get_type_name()] = before_map.get(op.get_type_name(), 0) + 1 - paged_attention_transformation(ov_model, use_block_indices_inputs, use_score_outputs) + paged_attention_transformation(ov_model, use_block_indices_inputs, use_score_outputs, allow_cache_rotation) after_map = {} for op in ov_model.get_ordered_ops(): @@ -36,7 +37,7 @@ def compare_diffs(ov_model: ov.Model, for op in set(after_map.keys()) | set(before_map.keys()): resulting_map[op] = after_map.get(op, 0) - before_map.get(op, 0) - use_cache_eviction = use_block_indices_inputs and use_score_outputs + 
use_cache_eviction = use_block_indices_inputs and use_score_outputs and allow_cache_rotation reference_map = ref_diff_map_cache_eviction[model_id] if use_cache_eviction else ref_diff_map[model_id] assert reference_map == resulting_map @@ -51,32 +52,47 @@ def compare_diffs(ov_model: ov.Model, assert shape[-1].is_static, f"Dimension {len(shape) - 1} of input '{name}' in '{model_id}' is not static: {shape}" assert shape[-2].is_static, f"Dimension {len(shape) - 2} of input '{name}' in '{model_id}' is not static: {shape}" - # Test for block_indices inputs and scores outputs to appear in the model + interesting_input_patterns = {} + interesting_output_patterns = {} + + + # Test for block_indices inputs and scores outputs to appear in the model if (use_block_indices_inputs): - block_indices_pattern = r'block_indices\.[0-9]+' - block_indices_counter = 0 - - model_inputs = ov_model.inputs - for input in model_inputs: - for name in list(input.get_names()): - if re.search(block_indices_pattern, name): - block_indices_counter += 1 - - assert block_indices_counter == resulting_map["PagedAttentionExtension"], \ - f"The number of block_indices inputs doesn't correspond to the expected value. Expected {resulting_map['PagedAttentionExtension']}, received {block_indices_counter}" - + interesting_input_patterns["block_indices"] = r'^block_indices\.[0-9]+' + if (use_score_outputs): - score_pattern = r'scores\.[0-9]+' - score_outputs_counter = 0 + interesting_output_patterns["scores"] = r'^scores\.[0-9]+' + + if (allow_cache_rotation): + interesting_input_patterns["rotated_block_indices"] = r'^rotated_block_indices\.[0-9]+'; + interesting_input_patterns["rotation_deltas"] = r'^rotation_deltas\.[0-9]+'; + interesting_input_patterns["rotation_trig_lut"] = r'rotation_trig_lut'; + + input_counters = {k: 0 for k in interesting_input_patterns} + output_counters = {k: 0 for k in interesting_output_patterns} + + for pattern_dict, counter_dict, io_set in zip([interesting_input_patterns, interesting_output_patterns], + [input_counters, output_counters], + [ov_model.inputs, ov_model.outputs]): + for input_id in counter_dict: + pattern = pattern_dict[input_id] + for model_io in io_set: + for name in list(model_io.get_names()): + if re.search(pattern, name): + counter_dict[input_id] += 1 + + if allow_cache_rotation: + assert input_counters["rotation_trig_lut"] == 1 + input_counters.pop("rotation_trig_lut") + + for input_id, count in input_counters.items(): + assert count == resulting_map["PagedAttentionExtension"], \ + f"The number of {input_id} inputs doesn't correspond to the expected value. Expected {resulting_map['PagedAttentionExtension']}, received {count}" - model_outputs = ov_model.outputs - for output in model_outputs: - for name in list(output.get_names()): - if re.search(score_pattern, name): - score_outputs_counter += 1 + for output_id, count in output_counters.items(): + assert count == resulting_map["PagedAttentionExtension"], \ + f"The number of {output_id} outputs doesn't correspond to the expected value. Expected {resulting_map['PagedAttentionExtension']}, received {count}" - assert block_indices_counter == resulting_map["PagedAttentionExtension"], \ - f"The number of scores outputs doesn't correspond to the expected value. 
Expected {resulting_map['PagedAttentionExtension']}, received {block_indices_counter}" @retry(3, exceptions=(OSError,), delay=1) def run_pa(tmp_path, @@ -84,11 +100,12 @@ def run_pa(tmp_path, model_link, cls: Union[Type[OVModelForCausalLM], Type[OVModelForVisualCausalLM]], use_block_indices_inputs, - use_score_outputs): + use_score_outputs, + allow_cache_rotation): model = cls.from_pretrained(model_id, export=True, trust_remote_code=True) ov_model = model.model if cls is OVModelForCausalLM else model.lm_model - compare_diffs(ov_model, model_id, use_block_indices_inputs, use_score_outputs) + compare_diffs(ov_model, model_id, use_block_indices_inputs, use_score_outputs, allow_cache_rotation) @pytest.mark.precommit @pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-models-precommit"))) @@ -99,7 +116,7 @@ def test_pa_precommit(tmp_path, model_name, model_link, mark, reason, ie_device) pytest.skip(reason) elif mark == 'xfail': pytest.xfail(reason) - run_pa(tmp_path, model_name, model_link, OVModelForCausalLM, False, False) + run_pa(tmp_path, model_name, model_link, OVModelForCausalLM, False, False, False) @pytest.mark.precommit @pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-models-precommit"))) @@ -110,7 +127,7 @@ def test_pa_precommit_use_cache_eviction(tmp_path, model_name, model_link, mark, pytest.skip(reason) elif mark == 'xfail': pytest.xfail(reason) - run_pa(tmp_path, model_name, model_link, OVModelForCausalLM, True, True) + run_pa(tmp_path, model_name, model_link, OVModelForCausalLM, True, True, True) @pytest.mark.precommit @pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-vl-models-precommit"))) @@ -121,7 +138,7 @@ def test_pa_vlm(tmp_path, model_name, model_link, mark, reason, ie_device): pytest.skip(reason) elif mark == 'xfail': pytest.xfail(reason) - run_pa(tmp_path, model_name, model_link, OVModelForVisualCausalLM, False, False) + run_pa(tmp_path, model_name, model_link, OVModelForVisualCausalLM, False, False, False) @pytest.mark.precommit @pytest.mark.parametrize("model_name, model_link, mark, reason", utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-vl-models-precommit"))) @@ -132,4 +149,4 @@ def test_pa_vlm_use_cache_eviction(tmp_path, model_name, model_link, mark, reaso pytest.skip(reason) elif mark == 'xfail': pytest.xfail(reason) - run_pa(tmp_path, model_name, model_link, OVModelForVisualCausalLM, True, True) \ No newline at end of file + run_pa(tmp_path, model_name, model_link, OVModelForVisualCausalLM, True, True, True)
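
For reference, the rotate-half transform that the new CPU unit tests validate boils down to the standalone sketch below. The function name and the plain-float buffers are illustrative assumptions for this sketch only; the actual kernels declared in cache_rotation.hpp are templated over the cache precision and have AVX2/AVX512-vectorized variants. The coefficient layout (per token row, the first embedding_size / 2 values are cosines and the second half are sines) is inferred from the fixture data in paged_attn_cache_rotation.cpp.

// rotate_half_reference.cpp -- minimal illustrative sketch, not the plugin's
// cache_rotation.hpp API.
#include <cstddef>
#include <cstdio>
#include <vector>

// cache_block:      [num_heads, block_size, embedding_size], row-major
// rotation_coeffts: [block_size, embedding_size]; per token row, the first
//                   embedding_size / 2 entries hold cos(theta), the rest sin(theta)
static void rotate_block_reference(std::vector<float>& cache_block,
                                   const std::vector<float>& rotation_coeffts,
                                   size_t num_heads,
                                   size_t block_size,
                                   size_t embedding_size) {
    const size_t half = embedding_size / 2;
    for (size_t head = 0; head < num_heads; ++head) {
        float* head_data = cache_block.data() + head * block_size * embedding_size;
        for (size_t tok = 0; tok < block_size; ++tok) {
            float* row = head_data + tok * embedding_size;
            const float* cos_row = rotation_coeffts.data() + tok * embedding_size;
            const float* sin_row = cos_row + half;
            for (size_t i = 0; i < half; ++i) {
                const float x = row[i];
                const float y = row[i + half];
                row[i]        = x * cos_row[i] - y * sin_row[i];  // x' = x*cos - y*sin
                row[i + half] = x * sin_row[i] + y * cos_row[i];  // y' = x*sin + y*cos
            }
        }
    }
}

int main() {
    // One head, one token, embedding_size = 4: the first row of the test fixture
    // (all-ones cache row, coefficients {0.5, 0.70710678 | 0.86602540, -0.70710678}).
    std::vector<float> cache   = {1.0f, 1.0f, 1.0f, 1.0f};
    std::vector<float> coeffts = {0.5f, 0.70710678f, 0.86602540f, -0.70710678f};
    rotate_block_reference(cache, coeffts, /*num_heads=*/1, /*block_size=*/1, /*embedding_size=*/4);
    for (float v : cache) {
        std::printf("%.8f ", v);
    }
    std::printf("\n");
    return 0;
}

Compiled as-is, the program prints -0.36602540 1.41421356 1.36602540 0.00000000, which matches the first row of ref_values_after_rotation in the test fixture above.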