Add cache rotation inputs and CPU kernel implementation for cache rotation (openvinotoolkit#27088)

Tickets: 153783
vshampor authored and MirceaDan99 committed Jan 22, 2025
1 parent d33131f commit 5994651
Showing 28 changed files with 1,694 additions and 396 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/job_cxx_unit_tests.yml
@@ -195,6 +195,12 @@ jobs:
${{ env.SETUPVARS_COMMAND }}
${{ env.INSTALL_TEST_DIR }}/ov_cpu_unit_tests --gtest_print_time=1 --gtest_output=xml:${{ env.INSTALL_TEST_DIR }}/TEST-CPUUnitTests.xml
- name: CPU plugin unit tests (vectorized)
if: fromJSON(inputs.affected-components).CPU.test
run: |
${{ env.SETUPVARS_COMMAND }}
${{ env.INSTALL_TEST_DIR }}/ov_cpu_unit_tests_vectorized --gtest_print_time=1 --gtest_output=xml:${{ env.INSTALL_TEST_DIR }}/TEST-CPUUnitTestsVectorized.xml
- name: ov_subgraphs_dumper_tests tests
run: |
${{ env.SETUPVARS_COMMAND }}
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -185,4 +185,4 @@ endif()
# provides a callback function to describe each component in repo
include(cmake/packaging/packaging.cmake)

ov_cpack(${OV_CPACK_COMPONENTS_ALL})
ov_cpack(${OV_CPACK_COMPONENTS_ALL})
[file path not rendered]
@@ -143,15 +143,18 @@ void regmodule_offline_transformations(py::module m) {

m_offline_transformations.def(
"paged_attention_transformation",
[](py::object& ie_api_model, bool use_block_indices_inputs, bool use_score_outputs) {
[](py::object& ie_api_model, bool use_block_indices_inputs, bool use_score_outputs, bool allow_cache_rotation) {
const auto model = Common::utils::convert_to_model(ie_api_model);
ov::pass::Manager manager;
manager.register_pass<ov::pass::SDPAToPagedAttention>(use_block_indices_inputs, use_score_outputs);
manager.register_pass<ov::pass::SDPAToPagedAttention>(use_block_indices_inputs,
use_score_outputs,
allow_cache_rotation);
manager.run_passes(model);
},
py::arg("model"),
py::arg("use_block_indices_inputs") = false,
py::arg("use_score_outputs") = false);
py::arg("use_score_outputs") = false,
py::arg("allow_cache_rotation") = false);

m_offline_transformations.def(
"stateful_to_stateless_transformation",
[file path not rendered]
@@ -24,8 +24,12 @@ class ov::pass::StateManagementPattern : public ov::pass::MatcherPass {
ParameterVector& parameters_to_remove,
int& layer_index,
ov::Output<Node> max_context_len,
ParameterVector& block_indices_inputs,
ParameterVector& block_indices_inputs_for_each_layer,
ResultVector& score_results,
bool use_block_indices,
bool use_score_outputs);
bool use_per_layer_block_indices_inputs,
bool use_score_outputs,
bool allow_cache_rotation,
ParameterVector& rotated_block_indices_inputs_for_each_layer,
ParameterVector& rotation_deltas_inputs_for_each_layer,
std::shared_ptr<op::v0::Parameter> model_rotation_trig_lut);
};
[file path not rendered]
@@ -15,6 +15,7 @@
#include "openvino/op/gather.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/paged_attention.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/op/select.hpp"
@@ -70,10 +71,14 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
ParameterVector& parameters_to_remove,
int& layer_index,
Output<Node> max_context_len,
ParameterVector& block_indices_inputs,
ParameterVector& block_indices_inputs_for_each_layer,
ResultVector& score_results,
bool use_block_indices_inputs,
bool use_score_outputs) {
bool use_per_layer_block_indices_inputs,
bool use_score_outputs,
bool allow_cache_rotation,
ParameterVector& rotated_block_indices_inputs_for_each_layer,
ParameterVector& rotation_deltas_inputs_for_each_layer,
std::shared_ptr<op::v0::Parameter> model_rotation_trig_lut) {
MATCHER_SCOPE(StateManagementPattern);

auto k_current = pattern::any_input();
@@ -176,9 +181,11 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
&model_remaining_params,
&sliding_window,
&parameters_to_remove,
&block_indices_inputs,
&block_indices_inputs_for_each_layer,
&score_results,
&layer_index](ov::pass::pattern::Matcher& m) {
&layer_index,
&rotated_block_indices_inputs_for_each_layer,
&rotation_deltas_inputs_for_each_layer](ov::pass::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto real_q = pattern_map.at(q);

@@ -382,11 +389,27 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par
max_context_len.get_node_shared_ptr()};
pa_arguments.insert(pa_arguments.end(), additional_params.begin(), additional_params.end());

if (use_block_indices_inputs) {
if (use_per_layer_block_indices_inputs) {
auto block_indices = setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1}),
"block_indices." + std::to_string(layer_index - 1));
pa_arguments.insert(pa_arguments.begin() + 7, block_indices);
block_indices_inputs.push_back(block_indices);
block_indices_inputs_for_each_layer.push_back(block_indices);
}

OPENVINO_ASSERT(pa_arguments.size() == 13);

if (allow_cache_rotation) {
auto rotated_block_indices = setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1}),
"rotated_block_indices." + std::to_string(layer_index - 1));
auto rotation_deltas = setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1, -1}),
"rotation_deltas." + std::to_string(layer_index - 1));

pa_arguments.insert(pa_arguments.begin() + 13, rotated_block_indices);
pa_arguments.insert(pa_arguments.begin() + 14, rotation_deltas);
pa_arguments.insert(pa_arguments.begin() + 15, model_rotation_trig_lut);

rotated_block_indices_inputs_for_each_layer.push_back(rotated_block_indices);
rotation_deltas_inputs_for_each_layer.push_back(rotation_deltas);
}

auto paged_attention = std::make_shared<ov::op::PagedAttentionExtension>(pa_arguments);
@@ -444,4 +467,4 @@ ov::pass::StateManagementPattern::StateManagementPattern(ParameterVector& kv_par

auto m = std::make_shared<ov::pass::pattern::Matcher>(sdpa_variants, matcher_name);
register_matcher(m, callback);
}
}
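Taken together, the insertions above imply a fixed positional layout for the extended PagedAttention argument list. The C++ sketch below spells that layout out; only positions 7 and 13 through 15 are pinned down by this diff, so the names of the remaining inputs are assumptions based on the op's pre-existing 13-input convention.

#include <cstddef>

// Hedged sketch of the PagedAttention input layout implied by the pass above.
// Positions 7 and 13-15 follow from the insertions in this commit; the other
// names are assumed from the existing 13-input convention.
enum PagedAttentionInputIndex : std::size_t {
    QUERY = 0,
    KEY = 1,
    VALUE = 2,
    KEY_CACHE = 3,
    VALUE_CACHE = 4,
    PAST_LENS = 5,
    SUBSEQUENCE_BEGINS = 6,
    BLOCK_INDICES = 7,           // per-layer when use_per_layer_block_indices_inputs is set
    BLOCK_INDICES_BEGINS = 8,
    SCALE = 9,
    SLIDING_WINDOW = 10,
    ALIBI_SLOPES = 11,
    MAX_CONTEXT_LEN = 12,
    // Optional cache-rotation inputs added by this commit (16-input form):
    ROTATED_BLOCK_INDICES = 13,  // i32, rank 1, one parameter per layer
    ROTATION_DELTAS = 14,        // i32, rank 2, one parameter per layer
    ROTATION_TRIG_LUT = 15,      // f32, rank 2, single parameter shared across layers
};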
7 changes: 5 additions & 2 deletions src/core/include/openvino/pass/sdpa_to_paged_attention.hpp
@@ -19,12 +19,15 @@ class OPENVINO_API SDPAToPagedAttention : public ModelPass {
public:
OPENVINO_MODEL_PASS_RTTI("SDPAToPagedAttention");

explicit SDPAToPagedAttention(bool use_block_indices_inputs = false, bool use_score_outputs = false);
explicit SDPAToPagedAttention(bool use_per_layer_block_indices_inputs = false,
bool use_score_outputs = false,
bool allow_cache_rotation = false);
bool run_on_model(const std::shared_ptr<ov::Model>& model) override;

private:
bool m_use_block_indices_inputs;
bool m_use_per_layer_block_indices_inputs;
bool m_use_score_outputs;
bool m_allow_cache_rotation;
};
} // namespace pass
} // namespace ov
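For orientation, here is a minimal C++ sketch of driving the extended pass, mirroring what the paged_attention_transformation Python binding shown earlier does; the wrapper function name is illustrative, and the flag values are just one possible combination.

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/sdpa_to_paged_attention.hpp"

// Illustrative wrapper; `model` is assumed to be a stateless LLM containing
// ScaledDotProductAttention ops.
void convert_to_paged_attention(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::SDPAToPagedAttention>(
        /* use_per_layer_block_indices_inputs = */ false,
        /* use_score_outputs = */ false,
        /* allow_cache_rotation = */ true);
    manager.run_passes(model);
}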
40 changes: 38 additions & 2 deletions src/core/src/op/paged_attention.cpp
@@ -19,8 +19,8 @@ void PagedAttentionExtension::validate_and_infer_types() {
OV_OP_SCOPE(PagedAttentionExtension_validate_and_infer_types);

NODE_VALIDATION_CHECK(this,
get_input_size() == 13,
"PagedAttensionExtension expects 13 inputs, but it has ",
get_input_size() == 13 || get_input_size() == 16,
"PagedAttensionExtension expects 13 or 16 inputs, but it has ",
get_input_size());

NODE_VALIDATION_CHECK(
@@ -147,6 +147,42 @@ void PagedAttentionExtension::validate_and_infer_types() {
get_input_element_type(12),
".");

if (get_input_size() == 16) {
NODE_VALIDATION_CHECK(
this,
get_input_partial_shape(13).rank().is_dynamic() || get_input_partial_shape(13).rank().get_length() == 1,
"Input `rotated_block_indices` should either have rank 1 or be omitted, but it has rank ",
get_input_partial_shape(13).rank().get_length(),
".");
NODE_VALIDATION_CHECK(this,
get_input_element_type(13).is_dynamic() || get_input_element_type(13) == element::i32,
"Element type of `rotated_block_indices` input should be i32, but it is ",
get_input_element_type(13),
".");
NODE_VALIDATION_CHECK(
this,
get_input_partial_shape(14).rank().is_dynamic() || get_input_partial_shape(14).rank().get_length() == 2,
"Input `rotation_deltas` should either have rank 2 or be omitted, but it has rank ",
get_input_partial_shape(14).rank().get_length(),
".");
NODE_VALIDATION_CHECK(this,
get_input_element_type(14).is_dynamic() || get_input_element_type(14) == element::i32,
"Element type of `rotation_deltas` input should be i32, but it is ",
get_input_element_type(14),
".");
NODE_VALIDATION_CHECK(
this,
get_input_partial_shape(15).rank().is_dynamic() || get_input_partial_shape(15).rank().get_length() == 2,
"Input `rotation_trig_lut` should either have rank 2 or be omitted, but it has rank ",
get_input_partial_shape(15).rank().get_length(),
".");
NODE_VALIDATION_CHECK(this,
get_input_element_type(15).is_dynamic() || get_input_element_type(15) == element::f32,
"Element type of `rotation_trig_lut` input should be f32, but it is ",
get_input_element_type(15),
".");
}

// value head_size may be not same with key
auto out_ps = get_input_partial_shape(0);
const auto& key_ps = get_input_partial_shape(1);
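In summary, the 16-input form adds three optional inputs whose ranks and element types are enforced by the checks above. The comment sketch below restates them; the dimension meanings are assumptions drawn from how the transformation wires these inputs, not from this file.

// Hedged summary of the optional cache-rotation inputs (16-input form).
// Ranks and element types come from the validation checks above; the
// dimension meanings are assumed, not stated in this diff.
// index 13: rotated_block_indices  i32, rank 1  // blocks whose cached keys need re-rotation (assumed)
// index 14: rotation_deltas        i32, rank 2  // per-block position deltas (assumed)
// index 15: rotation_trig_lut      f32, rank 2  // precomputed cos/sin table shared by all layers (assumed)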
44 changes: 33 additions & 11 deletions src/core/src/pass/sdpa_to_paged_attention.cpp
@@ -20,9 +20,12 @@

using namespace ov::op;

ov::pass::SDPAToPagedAttention::SDPAToPagedAttention(bool use_block_indices_inputs, bool use_score_outputs)
: m_use_block_indices_inputs(use_block_indices_inputs),
m_use_score_outputs(use_score_outputs) {}
ov::pass::SDPAToPagedAttention::SDPAToPagedAttention(bool use_per_layer_block_indices_inputs,
bool use_score_outputs,
bool allow_cache_rotation)
: m_use_per_layer_block_indices_inputs(use_per_layer_block_indices_inputs),
m_use_score_outputs(use_score_outputs),
m_allow_cache_rotation(allow_cache_rotation) {}

static std::shared_ptr<v0::Parameter> setName(std::shared_ptr<v0::Parameter> node, const char* name) {
// Set name for both node and output tensor (should be only one tensor, and any other names will be overriden by a
@@ -46,11 +49,18 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1}), "subsequence_begins"),
setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1}), "block_indices_begins"),
};
if (!m_use_block_indices_inputs) {
if (!m_use_per_layer_block_indices_inputs) {
auto block_indices = setName(std::make_shared<v0::Parameter>(element::i32, PartialShape{-1}), "block_indices");
model_remaining_params.insert(model_remaining_params.begin() + 2, block_indices);
}

std::shared_ptr<v0::Parameter> model_rotation_trig_lut;

if (m_allow_cache_rotation) {
model_rotation_trig_lut =
setName(std::make_shared<v0::Parameter>(element::f32, PartialShape{-1, -1}), "rotation_trig_lut");
}

auto sliding_window = v0::Constant::create(element::i32, Shape{}, {0}); // sliding_window

auto get_parameter = [=](const std::shared_ptr<ov::Model>& model,
@@ -91,7 +101,10 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
ParameterVector kv_parameters;
ParameterVector parameters_to_remove;
ResultVector results_to_remove; // # used, but cannot really track all Results in stateless model
ParameterVector block_indices_inputs;
ParameterVector block_indices_inputs_for_each_layer;
ParameterVector rotated_block_indices_inputs_for_each_layer;
ParameterVector rotation_deltas_inputs_for_each_layer;

ResultVector score_results;

std::shared_ptr<v0::Parameter> position_ids;
@@ -120,11 +133,14 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
parameters_to_remove,
layer_index,
max_context_len->output(0),
block_indices_inputs,
block_indices_inputs_for_each_layer,
score_results,
m_use_block_indices_inputs,
m_use_score_outputs);

m_use_per_layer_block_indices_inputs,
m_use_score_outputs,
m_allow_cache_rotation,
rotated_block_indices_inputs_for_each_layer,
rotation_deltas_inputs_for_each_layer,
model_rotation_trig_lut);
manager.register_pass<PrevSequenceLengthPattern>(unsqueezed_input_ids, max_context_len, position_ids);
manager.register_pass<TotalSequenceLengthPattern>(max_context_len);
manager.register_pass<TotalSequenceLengthPatternQwen>(max_context_len);
@@ -174,14 +190,20 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
model->remove_parameter(parameter);
}

if (m_use_block_indices_inputs) {
model->add_parameters(block_indices_inputs);
if (m_use_per_layer_block_indices_inputs) {
model->add_parameters(block_indices_inputs_for_each_layer);
}

if (m_use_score_outputs) {
model->add_results(score_results);
}

if (m_allow_cache_rotation) {
model->add_parameters(rotated_block_indices_inputs_for_each_layer);
model->add_parameters(rotation_deltas_inputs_for_each_layer);
model->add_parameters({model_rotation_trig_lut});
}

model->add_parameters(kv_parameters);
model->add_parameters(model_remaining_params);
model->add_parameters({std::move(max_context_len)});
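Because the new per-layer parameters are registered on the model with the names assigned above, a runtime caller can address them directly by tensor name. A hedged sketch follows, assuming an ov::InferRequest over the transformed model and using empty tensors for a step in which no blocks are rotated.

#include <string>

#include "openvino/runtime/infer_request.hpp"

// Hedged sketch: feed the cache-rotation inputs added by the pass above.
// Tensor names follow the setName() calls in the transformation; the empty
// shapes ("no blocks rotated this step") are an assumption.
void feed_rotation_inputs(ov::InferRequest& request, size_t num_layers) {
    for (size_t layer = 0; layer < num_layers; ++layer) {
        request.set_tensor("rotated_block_indices." + std::to_string(layer),
                           ov::Tensor(ov::element::i32, ov::Shape{0}));
        request.set_tensor("rotation_deltas." + std::to_string(layer),
                           ov::Tensor(ov::element::i32, ov::Shape{0, 1}));
    }
    // Single trig LUT parameter shared across all layers.
    request.set_tensor("rotation_trig_lut",
                       ov::Tensor(ov::element::f32, ov::Shape{0, 0}));
}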
[Diffs for the remaining changed files are not rendered on this page.]
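The CPU kernel advertised in the commit title is among the diffs not rendered here. Conceptually, cache rotation re-applies a RoPE-style rotation to cached key blocks whose logical positions have shifted (for example, after cache eviction), using the per-block deltas in rotation_deltas to index rows of the precomputed cos/sin table in rotation_trig_lut. A minimal scalar sketch follows; the channel pairing and the [cos | sin] row layout are assumptions, not taken from this commit.

#include <cstddef>

// Hedged sketch of one token's cache rotation, assuming the common RoPE
// convention of rotating channel pairs (i, i + head_size/2) and a LUT row
// laid out as [cos values | sin values].
void rotate_cache_token(float* token_data, std::size_t head_size,
                        const float* trig_lut_row) {
    const std::size_t half = head_size / 2;
    const float* cos_vals = trig_lut_row;         // first half of the row
    const float* sin_vals = trig_lut_row + half;  // second half of the row
    for (std::size_t i = 0; i < half; ++i) {
        const float x = token_data[i];
        const float y = token_data[i + half];
        token_data[i] = x * cos_vals[i] - y * sin_vals[i];
        token_data[i + half] = x * sin_vals[i] + y * cos_vals[i];
    }
}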
