Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix OpenVINO model runner #12750

Merged
merged 1 commit into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions vllm/attention/backends/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,7 @@ class OpenVINOAttentionMetadata:
# `model_executable`.
multi_modal_placeholder_index_maps: Optional[Dict[
str, MultiModalPlaceholderMap.IndexMap]]

# Enable/disable KV scales calculation. This is so that we can disable the
# calculation until after prefill and cuda graph capture.
enable_kv_scales_calculation: bool
11 changes: 5 additions & 6 deletions vllm/model_executor/model_loader/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import vllm.envs as envs
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import DeviceConfig, ModelConfig
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
_prune_hidden_states)
Expand Down Expand Up @@ -103,7 +103,6 @@ def __init__(
self,
ov_core: ov.Core,
model_config: ModelConfig,
device_config: DeviceConfig,
kv_cache_dtype: ov.Type,
) -> None:
super().__init__()
Expand Down Expand Up @@ -187,8 +186,7 @@ def sample(


def get_model(
model_config: ModelConfig,
device_config: DeviceConfig,
vllm_config: VllmConfig,
kv_cache_dtype: ov.Type,
**kwargs,
) -> torch.nn.Module:
Expand All @@ -201,5 +199,6 @@ def get_model(
"be added in the future. If this is important to you, "
"please open an issue on github.")

return OpenVINOCausalLM(ov_core, model_config, device_config,
kv_cache_dtype)
with set_current_vllm_config(vllm_config):
return OpenVINOCausalLM(ov_core, vllm_config.model_config,
kv_cache_dtype)
9 changes: 3 additions & 6 deletions vllm/worker/openvino_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,13 @@ def __init__(
):
self.ov_core = ov_core
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
cache_config = self.cache_config
model_config = self.model_config
self.is_driver_worker = is_driver_worker

self.device = self.device_config.device

self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.sliding_window = self.model_config.get_sliding_window()
self.block_size = self.cache_config.block_size

self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
Expand All @@ -81,8 +79,7 @@ def __init__(
self.model: nn.Module # Set after init_Model

def load_model(self) -> None:
self.model = get_model(model_config=self.model_config,
device_config=self.device_config,
self.model = get_model(vllm_config=self.vllm_config,
kv_cache_dtype=self.kv_cache_dtype,
ov_core=self.ov_core)

Expand Down