Skip to content

Commit

Permalink
Remove debug prints, use per-block rotation
Browse files Browse the repository at this point in the history
  • Loading branch information
vshampor committed Dec 20, 2024
1 parent 9bf9aa8 commit 92b1310
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 10 deletions.
10 changes: 3 additions & 7 deletions src/cpp/src/continuous_batching_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,14 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::init(
/* collect_attention_scores = */ true,
/* is_use_per_layer_cache_control = */ true);
m_rotation_deltas_stores.reserve(m_num_decoder_layers);
ov::Shape rotation_deltas_store_shape{scheduler_config.num_kv_blocks, m_scheduler->get_block_size()}; // last dim can be later changed to BLOCK_SIZE for per-token granularity
std::cout << "VSHAMPOR: memsetting and pushing delta stores" << std::endl;
ov::Shape rotation_deltas_store_shape{scheduler_config.num_kv_blocks, 1}; // last dim can be later changed to BLOCK_SIZE for per-token granularity
for (size_t i = 0; i < m_num_decoder_layers; i++) {
ov::Tensor store(ov::element::i32, rotation_deltas_store_shape);
std::memset(store.data(), 0, store.get_byte_size());
m_rotation_deltas_stores.push_back(store);
}

// const auto& eviction_config = m_scheduler->get_config().cache_eviction_config;
// size_t max_sequence_cache_occupation_length_in_blocks = (eviction_config.get_evictable_size() + eviction_config.get_recent_size()) / m_scheduler->get_block_size() + 1;
size_t max_sequence_cache_occupation_length_in_blocks = scheduler_config.max_num_batched_tokens + 1;
std::cout << "VSHAMPOR: max_sequence_occupation_length is " << max_sequence_cache_occupation_length_in_blocks << std::endl;
size_t embedding_size = device_config.get_head_size();
m_cache_rotation_calculator = std::make_shared<CacheRotationCalculator>(
m_scheduler->get_block_size(),
Expand Down Expand Up @@ -489,7 +485,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_compute_cache_rotation
m_current_step_rotated_block_indices_per_sequence[layer_idx][seq_id].push_back(
block_rotation_data.logical_block_idx);

size_t block_offset = num_blocks_to_rotate_for_each_layer[layer_idx] * m_scheduler->get_block_size();
size_t block_offset = num_blocks_to_rotate_for_each_layer[layer_idx];
auto rotation_deltas_tensor_data =
m_rotation_deltas_stores[layer_idx].data<int32_t>() + block_offset;
for (size_t tok_idx = 0; tok_idx < m_scheduler->get_block_size(); tok_idx++) {
Expand All @@ -504,7 +500,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_compute_cache_rotation
m_current_step_rotation_deltas.emplace_back(
m_rotation_deltas_stores[i],
ov::Coordinate{0, 0},
ov::Coordinate{num_blocks_to_rotate_for_each_layer[i], m_scheduler->get_block_size()});
ov::Coordinate{num_blocks_to_rotate_for_each_layer[i], 1});
}
}

Expand Down
1 change: 0 additions & 1 deletion src/cpp/src/model_runner.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ class ModelRunner {
{
static ManualTimer timer("pure generate inference");
timer.start();
std::cout << "VSHAMPOR: m_request.infer() " << std::endl;
m_request.infer();
timer.end();
}
Expand Down
2 changes: 0 additions & 2 deletions src/cpp/src/utils/paged_attention_transformations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,8 @@ void apply_paged_attention_transformations(std::shared_ptr<ov::Model> model, boo
bool use_block_indices_inputs = per_layer_cache_control;
bool use_score_outputs = per_layer_cache_control;
bool allow_cache_rotation = per_layer_cache_control;
std::cout << "VSHAMPOR: applying PA transforms, flags are " << use_block_indices_inputs << use_score_outputs << allow_cache_rotation << std::endl;
ov::pass::SDPAToPagedAttention(use_block_indices_inputs, use_score_outputs, allow_cache_rotation)
.run_on_model(model);
std::cout << "VSHAMPOR: transform done " << std::endl;
}

void set_kv_cache_type_and_shape(std::shared_ptr<ov::Model> model, DeviceConfig& device_config) {
Expand Down

0 comments on commit 92b1310

Please sign in to comment.