
Merge pull request vllm-project#28 from luo-cheng2021/luocheng/pa-kv-u8
[CPU] PagedAttention support u8 kvcache
ilya-lavrenov authored Apr 17, 2024
2 parents 3570043 + 14ea134 commit 307a6d1
Showing 4 changed files with 17 additions and 0 deletions.
5 changes: 5 additions & 0 deletions use_with_openvino.md
@@ -90,3 +90,8 @@ To pass the variable in docker, use `-e VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=1

The variable enables weights compression logic described in [optimum-intel 8-bit weights quantization](https://huggingface.co/docs/optimum/intel/optimization_ov#8-bit).
Hence, even if the variable is enabled, compression is applied only to models above a certain size; smaller models are left uncompressed, since compression would cause a significant accuracy drop.

## Use UInt8 KV Cache Compression

uint8 KV cache compression is disabled by default. For better performance and lower memory consumption, it can be enabled by setting the environment variable `VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8`.
To pass the variable in docker, use `-e VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8` as an additional argument to the `docker run` command in the examples above.
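
For a non-Docker run, the same variable can also be exported in the shell or set from Python before the engine is constructed, since the executor reads it at construction time (see the `os.environ.get` call added below). A minimal sketch, assuming the standard `vllm.LLM` entry point; the model name is only a placeholder:

```python
import os

# Must be set before the engine (and its OpenVINO executor) is created.
os.environ["VLLM_OPENVINO_CPU_KV_CACHE_PRECISION"] = "u8"

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-2-7b-hf")  # placeholder model id
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```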
9 changes: 9 additions & 0 deletions vllm/executor/openvino_executor.py
@@ -1,6 +1,7 @@
from typing import Dict, List, Optional, Tuple, Set
import torch.distributed
import gc
import os

from vllm.lora.request import LoRARequest
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
@@ -57,6 +58,9 @@ def __init__(
self.device_config = device_config

self.head_size = model_config.get_head_size()
if device_config.device.type == "cpu":
if cache_config.cache_dtype == "u8":
self.head_size += 8
self.num_layers = model_config.get_num_layers(parallel_config)
self.num_heads = model_config.get_num_kv_heads(parallel_config)

@@ -185,6 +189,9 @@ def get_cache_block_size(
device_config: DeviceConfig,
) -> int:
head_size = model_config.get_head_size()
if device_config.device.type == "cpu":
if cache_dtype == "u8":
head_size += 8
num_heads = model_config.get_num_kv_heads(parallel_config)
num_layers = model_config.get_num_layers(parallel_config)

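The `head_size += 8` adjustment appears in both `__init__` and `get_cache_block_size`, so the allocated cache rows and the reported block size stay consistent. A plausible reading is that the u8 cache stores per-row quantization parameters (an f32 scale and an f32 zero point, 8 bytes) next to the quantized values; the commit itself does not spell this out. A hedged sketch of how the adjustment feeds into the block-size arithmetic (the helper name and element sizes are illustrative, not the loader's exact code):

```python
def cache_block_size_bytes(block_size: int, num_heads: int, head_size: int,
                           num_layers: int, cache_dtype: str) -> int:
    """Illustrative re-derivation of the CPU cache block size."""
    if cache_dtype == "u8":
        head_size += 8   # assumed: room for f32 scale + f32 zero point per row
        elem_size = 1    # one byte per quantized element
    else:
        elem_size = 2    # fp16 / bf16 cache
    key_block = block_size * num_heads * head_size
    value_block = key_block
    return num_layers * (key_block + value_block) * elem_size

# Example: 32 layers, 32 KV heads, head_size 128, 16-token blocks
print(cache_block_size_bytes(16, 32, 128, 32, "u8"))    # u8 cache
print(cache_block_size_bytes(16, 32, 128, 32, "bf16"))  # bf16 cache
```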
@@ -455,6 +462,8 @@ def __init__(
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
) -> None:
if os.environ.get("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION", "") == "u8":
cache_config.cache_dtype = "u8"
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
2 changes: 2 additions & 0 deletions vllm/model_executor/openvino_model_loader.py
@@ -413,6 +413,7 @@ def _patch_model_with_openvino(
Type.f32: torch.float32,
Type.f16: torch.float16,
Type.bf16: torch.bfloat16,
Type.u8: torch.uint8,
Type.i32: torch.int32,
Type.i64: torch.int64
}
@@ -521,6 +522,7 @@ def paged_attention_convertion(context):
torch.float32: Type.f32,
torch.float16: Type.f16,
torch.bfloat16: Type.bf16,
torch.uint8: Type.u8,
torch.int32: Type.i32,
torch.int64: Type.i64
}
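
Both lookup tables map between OpenVINO element types and torch dtypes, so the u8 entry is needed in each direction for the quantized KV cache tensors to round-trip through the patched model. A small self-contained sketch of the same idea (the dict names are illustrative, not the exact ones in the loader):

```python
from openvino.runtime import Type
import torch

# Illustrative two-way mapping; the loader keeps two separate dicts.
ov_to_torch = {
    Type.f32: torch.float32,
    Type.f16: torch.float16,
    Type.bf16: torch.bfloat16,
    Type.u8: torch.uint8,   # added by this commit for the u8 KV cache
    Type.i32: torch.int32,
    Type.i64: torch.int64,
}
torch_to_ov = {v: k for k, v in ov_to_torch.items()}

assert torch_to_ov[torch.uint8] == Type.u8
```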
1 change: 1 addition & 0 deletions vllm/utils.py
@@ -30,6 +30,7 @@
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8_e5m2": torch.uint8,
"u8": torch.uint8
}


