
Commit 8a6b1ef

[NPUW] Extend npuw::LLMInferRequest to support LM from VLMs (openvinotoolkit#29000)
Co-authored-by: Anastasiya Pronina <[email protected]>
1 parent 2a0f595 commit 8a6b1ef

File tree

3 files changed, +76 -27 lines changed


src/plugins/intel_npu/src/plugin/npuw/llm_compiled_model.cpp

+12 -9
@@ -41,7 +41,7 @@ class TransposeValueTensors : public ov::pass::MatcherPass {
         auto matched_matmul = std::static_pointer_cast<ov::op::v0::MatMul>(node_matmul);
 
         auto param_shape = matched_param->get_partial_shape();
-        OPENVINO_ASSERT(param_shape.size() == 4u);
+        NPUW_ASSERT(param_shape.size() == 4u);
         // NB: Transpose Parameter that correspond to V-tensor it will
         // speed-up its multiplication with attention scores
         std::swap(param_shape[2], param_shape[3]);
@@ -150,7 +150,7 @@ class TransposeValueTensors_llama3 : public TransposeValueTensors {
         auto matched_reshape = std::static_pointer_cast<ov::op::v1::Reshape>(matched_node_reshape);
 
         auto shape_broadcast = matched_broadcast->get_output_shape(0);
-        OPENVINO_ASSERT(shape_broadcast.size() == 5u);
+        NPUW_ASSERT(shape_broadcast.size() == 5u);
         std::swap(shape_broadcast[3], shape_broadcast[4]);
 
         LOG_DEBUG("shape_broadcast for: " << matched_broadcast->get_friendly_name()
@@ -162,7 +162,7 @@ class TransposeValueTensors_llama3 : public TransposeValueTensors {
         matched_broadcast->input(1).replace_source_output(broadcast_axes_node);
 
         auto shape_reshape = matched_reshape->get_output_shape(0);
-        OPENVINO_ASSERT(shape_reshape.size() == 4u);
+        NPUW_ASSERT(shape_reshape.size() == 4u);
         std::swap(shape_reshape[2], shape_reshape[3]);
 
         LOG_DEBUG("shape_reshape for: " << matched_reshape->get_friendly_name() << ", shape=" << shape_reshape);
@@ -371,6 +371,11 @@ void reshape_to_static(std::shared_ptr<ov::Model> model,
         ov::PartialShape new_shape;
         if (input_name.find("input_ids") != std::string::npos) {
             new_shape = ov::PartialShape({1, input_size});
+        } else if (input_name.find("inputs_embeds") != std::string::npos) {
+            // NB: VLMs case, model accepts inputs_embeds[BATCH, SEQ_LEN, EMB_SIZE]
+            NPUW_ASSERT(input.get_partial_shape().size() == 3u);
+            NPUW_ASSERT(input.get_partial_shape()[2].is_static());
+            new_shape = ov::PartialShape({1, input_size, input.get_partial_shape()[2]});
         } else if (input_name.find("attention_mask") != std::string::npos) {
             new_shape = ov::PartialShape({1, kvcache_size});
         } else if (input_name.find("position_ids") != std::string::npos) {
@@ -628,14 +633,12 @@ ov::npuw::LLMCompiledModel::LLMCompiledModel(const std::shared_ptr<ov::Model>& m
 
     m_kvcache_compiled = std::dynamic_pointer_cast<ov::npuw::CompiledModel>(
         ov::npuw::ICompiledModel::create(kvcache_model, plugin, generate_config));
-    OPENVINO_ASSERT(m_kvcache_compiled,
-                    "Can't create ov::npuw::CompiledModel for passed kvcache "
-                    "model and its config, please check passed config.");
+    NPUW_ASSERT(m_kvcache_compiled && "Can't create ov::npuw::CompiledModel for passed kvcache "
+                                      "model and its config, please check passed config.");
     m_prefill_compiled = std::dynamic_pointer_cast<ov::npuw::CompiledModel>(
         ov::npuw::ICompiledModel::create(prefill_model, plugin, prefill_config));
-    OPENVINO_ASSERT(m_prefill_compiled,
-                    "Can't create ov::npuw::CompiledModel for passed prefill "
-                    "model and its config, please check passed config.");
+    NPUW_ASSERT(m_prefill_compiled && "Can't create ov::npuw::CompiledModel for passed prefill "
+                                      "model and its config, please check passed config.");
 
     implement_properties();
     LOG_DEBUG("Done");

src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.cpp

+60 -16
@@ -18,6 +18,11 @@ void fill_tensor(ov::SoPtr<ov::ITensor> tensor, T fill_val, size_t offset = 0u)
     std::fill(tensor_data + offset, tensor_data + tensor->get_size(), fill_val);
 }
 
+void fill_tensor_bytes(ov::SoPtr<ov::ITensor> tensor, uint8_t fill_val) {
+    auto* tensor_data = reinterpret_cast<uint8_t*>(tensor->data());
+    std::fill_n(tensor_data, tensor->get_byte_size(), fill_val);
+}
+
 ov::SoPtr<ov::ITensor> make_tensor_slice(ov::SoPtr<ov::ITensor> tensor,
                                          uint32_t dim,
                                          uint32_t start_pos,
@@ -100,6 +105,20 @@ void copy_columns_by_row_chunks(ov::SoPtr<ov::ITensor> src, ov::SoPtr<ov::ITenso
         std::copy_n(src_p + src_offset, chunk_byte_size, dst_p + dst_offset);
     }
 }
+
+std::optional<ov::Output<const ov::Node>> find_port_by_name(const std::vector<ov::Output<const ov::Node>>& ports,
+                                                            const std::string& name) {
+    auto it = std::find_if(ports.begin(), ports.end(), [&](const auto& port) {
+        return port.get_names().count(name) != 0;
+    });
+    if (it == ports.end()) {
+        return std::nullopt;
+    }
+    return std::make_optional(*it);
+}
+
+constexpr uint32_t INPUT_IDS_SEQ_LEN_DIM = 1;
+
 }  // anonymous namespace
 
 ov::npuw::LLMInferRequest::LLMInferRequest(const std::shared_ptr<ov::npuw::LLMCompiledModel>& compiled_model)
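
Why the byte-wise helper: fill_tensor<int64_t> assumes an i64 element type, but the prefill input may now be f32 inputs_embeds. Zeroing by bytes is type-agnostic, since all-zero bytes decode to 0 in i64 and 0.0f in f32 alike. A standalone sketch of the same idea over ov::Tensor; the function name is illustrative:

#include <algorithm>
#include <cstdint>
#include <openvino/runtime/tensor.hpp>

void zero_tensor_bytes(ov::Tensor& tensor) {
    auto* bytes = static_cast<uint8_t*>(tensor.data());
    // get_byte_size() covers the whole buffer regardless of element type
    std::fill_n(bytes, tensor.get_byte_size(), uint8_t{0});
}
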
@@ -112,6 +131,14 @@ ov::npuw::LLMInferRequest::LLMInferRequest(const std::shared_ptr<ov::npuw::LLMCo
         init_tensor(output_port);
     }
 
+    auto input_ids_port = find_port_by_name(compiled_model->m_prefill_compiled->inputs(), "input_ids");
+    if (input_ids_port.has_value()) {
+        m_input_ids_name = "input_ids";
+    } else {
+        OPENVINO_ASSERT(find_port_by_name(compiled_model->m_prefill_compiled->inputs(), "inputs_embeds").has_value());
+        m_input_ids_name = "inputs_embeds";
+    }
+
     m_kvcache_request = compiled_model->m_kvcache_compiled->create_infer_request();
     m_prefill_request = compiled_model->m_prefill_compiled->create_infer_request();
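
The detection above runs once per request: if the prefill model exposes an input_ids port, the request feeds token ids; otherwise it must expose inputs_embeds and the request feeds precomputed embeddings (the VLM language-model case). The same logic as a small function, reusing find_port_by_name from the anonymous namespace; detect_input_name is an illustrative name, not part of the commit:

std::string detect_input_name(const std::vector<ov::Output<const ov::Node>>& prefill_inputs) {
    if (find_port_by_name(prefill_inputs, "input_ids").has_value()) {
        return "input_ids";  // text-only LLM: the model embeds token ids itself
    }
    // VLM language model: embeddings are computed outside and passed in
    OPENVINO_ASSERT(find_port_by_name(prefill_inputs, "inputs_embeds").has_value());
    return "inputs_embeds";
}
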

@@ -152,7 +179,7 @@ void ov::npuw::LLMInferRequest::init_tensor(const ov::Output<const ov::Node>& po
 }
 
 void ov::npuw::LLMInferRequest::prepare_for_new_conversation() {
-    fill_tensor<int64_t>(m_prefill_request->get_tensor(m_prefill_in_ports.at("input_ids")), 0);
+    fill_tensor_bytes(m_prefill_request->get_tensor(m_prefill_in_ports.at(m_input_ids_name)), 0u);
     fill_tensor<int64_t>(m_prefill_request->get_tensor(m_prefill_in_ports.at("attention_mask")), 0);
     fill_tensor<int64_t>(m_prefill_request->get_tensor(m_prefill_in_ports.at("position_ids")), 0);
     fill_tensor<int64_t>(m_kvcache_request->get_tensor(m_kvcache_in_ports.at("attention_mask")), 0);
@@ -167,20 +194,29 @@ void ov::npuw::LLMInferRequest::infer_prefill(ov::SoPtr<ov::ITensor> input_ids,
 
     prepare_for_new_conversation();
 
-    auto padded_input_ids = m_prefill_request->get_tensor(m_prefill_in_ports.at("input_ids"));
-    const size_t offset = padded_input_ids->get_size() - input_ids->get_size();
-    std::copy_n(input_ids->data<int64_t>(), input_ids->get_size(), padded_input_ids->data<int64_t>() + offset);
+    auto padded_input = m_prefill_request->get_tensor(m_prefill_in_ports.at(m_input_ids_name));
+    // NB: padded_input can be either fp32(VLM) or i64(LLM)
+    std::copy_n(
+        reinterpret_cast<uint8_t*>(input_ids->data()),
+        input_ids->get_byte_size(),
+        reinterpret_cast<uint8_t*>(padded_input->data()) + padded_input->get_byte_size() - input_ids->get_byte_size());
 
     auto padded_attention_mask = m_prefill_request->get_tensor(m_prefill_in_ports.at("attention_mask"));
-    std::copy_n(attention_mask->data<int64_t>(),
-                attention_mask->get_size(),
-                padded_attention_mask->data<int64_t>() + offset);
+    std::copy_n(
+        attention_mask->data<int64_t>(),
+        attention_mask->get_size(),
+        padded_attention_mask->data<int64_t>() + padded_attention_mask->get_size() - attention_mask->get_size());
 
     auto padded_position_ids = m_prefill_request->get_tensor(m_prefill_in_ports.at("position_ids"));
-    std::copy_n(position_ids->data<int64_t>(), position_ids->get_size(), padded_position_ids->data<int64_t>() + offset);
+
+    std::copy_n(position_ids->data<int64_t>(),
+                position_ids->get_size(),
+                padded_position_ids->data<int64_t>() + padded_position_ids->get_size() - position_ids->get_size());
 
     m_prefill_request->infer();
-    m_npuw_llm_compiled_model->m_kvcache_desc.num_stored_tokens += static_cast<uint32_t>(input_ids->get_size());
+
+    m_npuw_llm_compiled_model->m_kvcache_desc.num_stored_tokens +=
+        static_cast<uint32_t>(input_ids->get_shape()[INPUT_IDS_SEQ_LEN_DIM]);
     m_need_copy_kvcache = true;
 
     m_logits = m_prefill_request->get_tensor(m_prefill_out_ports.at("logits"));
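
The prefill copies are now expressed in bytes, so one code path serves both i64 token ids and f32 embeddings, and the destination offset right-aligns the payload inside the zero-padded static tensor (left padding). A standalone sketch of that copy over ov::Tensor, with illustrative names:

#include <algorithm>
#include <cstdint>
#include <openvino/core/except.hpp>
#include <openvino/runtime/tensor.hpp>

void copy_right_aligned(const ov::Tensor& src, ov::Tensor& dst) {
    OPENVINO_ASSERT(dst.get_byte_size() >= src.get_byte_size());
    const auto* src_bytes = static_cast<const uint8_t*>(src.data());
    auto* dst_bytes = static_cast<uint8_t*>(dst.data());
    // Leading bytes keep their zero padding; the payload lands at the tail.
    std::copy_n(src_bytes, src.get_byte_size(), dst_bytes + dst.get_byte_size() - src.get_byte_size());
}
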
@@ -244,8 +280,11 @@ void ov::npuw::LLMInferRequest::infer_generate(ov::SoPtr<ov::ITensor> input_ids,
     }
 
     // FIXME: these tensors should be shared between the parent & child models
-    auto kv_input_ids = m_kvcache_request->get_tensor(m_kvcache_in_ports.at("input_ids"));
-    std::copy_n(input_ids->data<int64_t>(), input_ids->get_size(), kv_input_ids->data<int64_t>());
+    auto kv_input_ids = m_kvcache_request->get_tensor(m_kvcache_in_ports.at(m_input_ids_name));
+    // NB: input_ids can be either fp32(VLM) or i64(LLM)
+    std::copy_n(reinterpret_cast<uint8_t*>(input_ids->data()),
+                input_ids->get_byte_size(),
+                reinterpret_cast<uint8_t*>(kv_input_ids->data()));
 
     auto kv_attn_mask = m_kvcache_request->get_tensor(m_kvcache_in_ports.at("attention_mask"));
     std::copy_n(attention_mask->data<int64_t>(), attention_mask->get_size() - 1, kv_attn_mask->data<int64_t>());
@@ -290,15 +329,20 @@ void ov::npuw::LLMInferRequest::infer_generate(ov::SoPtr<ov::ITensor> input_ids,
 void ov::npuw::LLMInferRequest::infer() {
     const auto& inputs = get_inputs();
 
-    auto input_ids = get_tensor(inputs[0]);
-    auto attention_mask = get_tensor(inputs[1]);
-    auto position_ids = get_tensor(inputs[2]);
+    auto input_ids = get_tensor(find_port_by_name(inputs, m_input_ids_name).value());
+    auto attention_mask = get_tensor(find_port_by_name(inputs, "attention_mask").value());
+    // FIXME: position_ids might be optional for some models!
+    auto position_ids = get_tensor(find_port_by_name(inputs, "position_ids").value());
 
-    OPENVINO_ASSERT(ov::element::i64 == input_ids->get_element_type());
+    // NB: For VLM, the "inputs_embeds" contains float values (embeddings)
+    OPENVINO_ASSERT(ov::element::f32 == input_ids->get_element_type() ||
+                    ov::element::i64 == input_ids->get_element_type());
     OPENVINO_ASSERT(ov::element::i64 == attention_mask->get_element_type());
     OPENVINO_ASSERT(ov::element::i64 == position_ids->get_element_type());
 
-    if (input_ids->get_size() != 1) {
+    // NB: Check the sequence length provided for input_ids
+    // in order to distinguish prefill / generate stages
+    if (input_ids->get_shape()[INPUT_IDS_SEQ_LEN_DIM] != 1) {
         infer_prefill(input_ids, attention_mask, position_ids);
     } else {
         infer_generate(input_ids, attention_mask, position_ids);
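
The dispatch criterion changes from element count to the sequence-length dimension: a single VLM embedding row has EMB_SIZE elements even though it encodes one token, so the old get_size() != 1 check would misroute single-token generation into prefill. A minimal sketch of the new check, assuming dimension 1 is SEQ_LEN for both [1, SEQ_LEN] ids and [1, SEQ_LEN, EMB_SIZE] embeddings:

#include <cstddef>
#include <openvino/runtime/tensor.hpp>

bool is_prefill_step(const ov::Tensor& input) {
    constexpr size_t SEQ_LEN_DIM = 1;  // [1, SEQ_LEN] or [1, SEQ_LEN, EMB_SIZE]
    // More than one token in flight means we are still prefilling the prompt.
    return input.get_shape()[SEQ_LEN_DIM] != 1;
}
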

src/plugins/intel_npu/src/plugin/npuw/llm_infer_request.hpp

+4 -2
@@ -31,9 +31,8 @@ class LLMInferRequest final : public ov::ISyncInferRequest {
     }
 
 private:
-    void init_tensor(const ov::Output<const ov::Node>& port);
-
     void prepare_for_new_conversation();
+    void init_tensor(const ov::Output<const ov::Node>& port);
 
     void infer_prefill(ov::SoPtr<ov::ITensor> input_ids,
                        ov::SoPtr<ov::ITensor> attention_mask,
@@ -53,6 +52,9 @@ class LLMInferRequest final : public ov::ISyncInferRequest {
     std::unordered_map<std::string, ov::Output<const ov::Node>> m_prefill_out_ports;
     std::unordered_map<std::string, ov::Output<const ov::Node>> m_kvcache_in_ports;
     std::unordered_map<std::string, ov::Output<const ov::Node>> m_kvcache_out_ports;
+
+    // NB: It can be either input_ids(LLM) or inputs_embeds(VLM)
+    std::string m_input_ids_name;
 };
 
 }  // namespace npuw
