Skip to content

Commit

Permalink
Fix OpenVINO device
Browse files Browse the repository at this point in the history
Signed-off-by: Harry Mellor <[email protected]>
  • Loading branch information
hmellor committed Feb 4, 2025
1 parent 18016a5 commit 231a230
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
4 changes: 4 additions & 0 deletions vllm/attention/backends/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,7 @@ class OpenVINOAttentionMetadata:
# `model_executable`.
multi_modal_placeholder_index_maps: Optional[Dict[
str, MultiModalPlaceholderMap.IndexMap]]

# Enable/disable KV scales calculation, so that the calculation can stay
# disabled until after prefill and CUDA graph capture.
enable_kv_scales_calculation: bool
11 changes: 5 additions & 6 deletions vllm/model_executor/model_loader/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import vllm.envs as envs
from vllm.attention.backends.openvino import OpenVINOAttentionMetadata
from vllm.config import DeviceConfig, ModelConfig
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import (LogitsProcessor,
_prune_hidden_states)
Expand Down Expand Up @@ -103,7 +103,6 @@ def __init__(
self,
ov_core: ov.Core,
model_config: ModelConfig,
device_config: DeviceConfig,
kv_cache_dtype: ov.Type,
) -> None:
super().__init__()
Expand Down Expand Up @@ -187,8 +186,7 @@ def sample(


def get_model(
model_config: ModelConfig,
device_config: DeviceConfig,
vllm_config: VllmConfig,
kv_cache_dtype: ov.Type,
**kwargs,
) -> torch.nn.Module:
Expand All @@ -201,5 +199,6 @@ def get_model(
"be added in the future. If this is important to you, "
"please open an issue on github.")

return OpenVINOCausalLM(ov_core, model_config, device_config,
kv_cache_dtype)
with set_current_vllm_config(vllm_config):
return OpenVINOCausalLM(ov_core, vllm_config.model_config,
kv_cache_dtype)
9 changes: 3 additions & 6 deletions vllm/worker/openvino_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,13 @@ def __init__(
):
self.ov_core = ov_core
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
cache_config = self.cache_config
model_config = self.model_config
self.is_driver_worker = is_driver_worker

self.device = self.device_config.device

self.kv_cache_dtype = kv_cache_dtype
self.sliding_window = model_config.get_sliding_window()
self.block_size = cache_config.block_size
self.sliding_window = self.model_config.get_sliding_window()
self.block_size = self.cache_config.block_size

self.attn_backend = get_attn_backend(
self.model_config.get_head_size(),
Expand All @@ -81,8 +79,7 @@ def __init__(
self.model: nn.Module # Set by load_model()

def load_model(self) -> None:
self.model = get_model(model_config=self.model_config,
device_config=self.device_config,
self.model = get_model(vllm_config=self.vllm_config,
kv_cache_dtype=self.kv_cache_dtype,
ov_core=self.ov_core)

Expand Down

0 comments on commit 231a230

Please sign in to comment.