[Hardware] [Intel] Enable Multiprocessing and tensor parallel in CPU backend and update documentation #6125

Merged Jul 26, 2024 (36 commits)

Commits:
be05b47  Adapt MP to cpu_executor  (bigPYJ1151, Jul 3, 2024)
be11ea5  Add CI  (bigPYJ1151, Jul 3, 2024)
3b56f13  update  (bigPYJ1151, Jul 3, 2024)
e1fd267  update  (bigPYJ1151, Jul 3, 2024)
c0ab019  update doc  (bigPYJ1151, Jul 3, 2024)
3d653f6  format  (bigPYJ1151, Jul 3, 2024)
84d8029  update  (bigPYJ1151, Jul 3, 2024)
c7fee29  Fix  (bigPYJ1151, Jul 4, 2024)
0d85c41  trigger  (bigPYJ1151, Jul 4, 2024)
991e8b3  trigger  (bigPYJ1151, Jul 4, 2024)
1845d95  trigger  (bigPYJ1151, Jul 4, 2024)
766705e  fix  (bigPYJ1151, Jul 4, 2024)
fae9534  Fix  (bigPYJ1151, Jul 4, 2024)
6cfbedf  fix conflict  (bigPYJ1151, Jul 25, 2024)
51d0eda  disable KMP setting dumping  (bigPYJ1151, Jul 5, 2024)
ddcae0d  trigger  (bigPYJ1151, Jul 5, 2024)
331b4f5  trigger  (bigPYJ1151, Jul 5, 2024)
8af42c0  trigger  (bigPYJ1151, Jul 5, 2024)
f61b9a8  trigger  (bigPYJ1151, Jul 9, 2024)
c3d028e  Update  (bigPYJ1151, Jul 11, 2024)
6be36d6  Add shm op  (bigPYJ1151, Jul 12, 2024)
f1fbb24  Revert "Add shm op"  (bigPYJ1151, Jul 12, 2024)
4a13516  Add IPEX Allreduce  (bigPYJ1151, Jul 12, 2024)
ecd2cac  remove prune  (bigPYJ1151, Jul 15, 2024)
6fb8f90  update IPEX  (bigPYJ1151, Jul 16, 2024)
43766b5  update doc  (bigPYJ1151, Jul 19, 2024)
eafb75f  retrigger  (bigPYJ1151, Jul 22, 2024)
cfbeae5  fix comments  (bigPYJ1151, Jul 22, 2024)
3f7384f  Update csrc/cpu/utils.cpp  (bigPYJ1151, Jul 22, 2024)
cc4b330  fix doc  (bigPYJ1151, Jul 22, 2024)
3268497  retrigger  (bigPYJ1151, Jul 22, 2024)
63e085c  retrigger  (bigPYJ1151, Jul 24, 2024)
7f48a4f  remove assert  (bigPYJ1151, Jul 25, 2024)
c87f0ab  add time out  (bigPYJ1151, Jul 26, 2024)
a5a13cd  Revert "add time out"  (bigPYJ1151, Jul 26, 2024)
aceded9  using random dataset  (bigPYJ1151, Jul 26, 2024)
Files changed:
28 changes: 20 additions & 8 deletions .buildkite/run-cpu-test.sh
@@ -3,26 +3,38 @@
 set -ex
 
 # Try building the docker image
-docker build -t cpu-test -f Dockerfile.cpu .
-docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu .
+numactl -C 48-95 -N 1 docker build -t cpu-test -f Dockerfile.cpu .
+numactl -C 48-95 -N 1 docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu .
 
 # Setup cleanup
 remove_docker_container() { docker rm -f cpu-test cpu-test-avx2 || true; }
 trap remove_docker_container EXIT
 remove_docker_container
 
-# Run the image
+# Run the image, setting --shm-size=4g for tensor parallel.
 docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \
-  --cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test
+  --cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test
 docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \
-  --cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test-avx2 cpu-test-avx2
+  --cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-avx2 cpu-test-avx2
 
 # offline inference
 docker exec cpu-test bash -c "python3 examples/offline_inference.py"
 docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
 
 # Run basic model test
-docker exec cpu-test bash -c "cd tests;
+docker exec cpu-test bash -c "
   pip install pytest Pillow protobuf
   cd ../
   pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported
+
+# online inference
+docker exec cpu-test bash -c "
+  export VLLM_CPU_KVCACHE_SPACE=10
+  export VLLM_CPU_OMP_THREADS_BIND=48-92
+  python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m &
+  timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
+  python3 benchmarks/benchmark_serving.py \
+    --backend vllm \
+    --dataset-name random \
+    --model facebook/opt-125m \
+    --num-prompts 20 \
+    --endpoint /v1/completions \
+    --tokenizer facebook/opt-125m"
9 changes: 5 additions & 4 deletions Dockerfile.cpu
@@ -2,8 +2,8 @@

 FROM ubuntu:22.04 AS cpu-test-1
 
-RUN apt-get update -y \
-    && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 \
+RUN apt-get update -y \
+    && apt-get install -y curl git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \
     && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
 
 # https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html
@@ -13,8 +13,9 @@ RUN pip install intel-openmp
 
 ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so:$LD_PRELOAD"
 
+RUN echo 'ulimit -c 0' >> ~/.bashrc
 
-RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl
 
 RUN pip install --upgrade pip \
     && pip install wheel packaging ninja "setuptools>=49.4.0" numpy
@@ -25,7 +26,7 @@ COPY ./ /workspace/vllm
 
 WORKDIR /workspace/vllm
 
-RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
+RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/test/cpu
 
 # Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ...
 ARG VLLM_CPU_DISABLE_AVX512
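
Because the Dockerfile now pins a development IPEX 2.4.0 wheel and pulls torch from the 2.4 test index, a quick sanity check inside the built image can confirm the intended versions landed. This is an illustrative sketch, not part of the PR; the expected version strings are inferred from the wheel URL and requirements change below:

# Illustrative version check for the image built from this Dockerfile.
import torch
import intel_extension_for_pytorch as ipex

print("torch:", torch.__version__)  # expected 2.4.0 per requirements-cpu.txt
print("ipex :", ipex.__version__)   # expected 2.4.0+git... per the wheel URL
assert ipex.__version__.startswith("2.4"), "unexpected IPEX version"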
4 changes: 4 additions & 0 deletions cmake/cpu_extension.cmake
@@ -83,6 +83,8 @@ endif()

 message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
 
+list(APPEND LIBS "numa")
+
 
 #
 # Define extension targets
@@ -95,6 +97,7 @@ set(VLLM_EXT_SRC
     "csrc/cpu/activation.cpp"
     "csrc/cpu/attention.cpp"
     "csrc/cpu/cache.cpp"
+    "csrc/cpu/utils.cpp"
     "csrc/cpu/layernorm.cpp"
     "csrc/cpu/pos_encoding.cpp"
     "csrc/cpu/torch_bindings.cpp")
@@ -104,6 +107,7 @@ define_gpu_extension_target(
     DESTINATION vllm
     LANGUAGE CXX
     SOURCES ${VLLM_EXT_SRC}
+    LIBRARIES ${LIBS}
     COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
     USE_SABI 3
     WITH_SOABI
7 changes: 7 additions & 0 deletions csrc/cpu/torch_bindings.cpp
@@ -4,6 +4,8 @@

 #include <torch/library.h>
 
+void init_cpu_threads_env(const std::string& cpu_ids);
+
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // vLLM custom ops
 
@@ -107,4 +109,9 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
   cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
 }
 
+TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
+  // CPU utils
+  utils.def("init_cpu_threads_env(str cpu_ids) -> ()", &init_cpu_threads_env);
+}
+
 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
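
Once the extension is built and loaded, the new op is reachable from Python under the `_C_utils` namespace (the `CONCAT` above expands `TORCH_EXTENSION_NAME`, i.e. `_C`, plus `_utils`). A minimal sketch follows; the "0-7" core list is illustrative, and importing vllm._C to register the ops is an assumption about how the extension gets loaded:

# Illustrative invocation of the newly bound CPU-utils op.
import torch
import vllm._C  # noqa: F401  (assumed to register torch.ops._C_utils)

# Pin this process's OpenMP threads to CPU cores 0-7 and bind memory
# allocation to the matching NUMA node.
torch.ops._C_utils.init_cpu_threads_env("0-7")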
65 changes: 65 additions & 0 deletions csrc/cpu/utils.cpp
@@ -0,0 +1,65 @@
#include <numa.h>
#include <unistd.h>
#include <string>
#include <sched.h>

#include "cpu_types.hpp"

void init_cpu_threads_env(const std::string& cpu_ids) {
  bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str());
  TORCH_CHECK(omp_cpu_mask->size > 0);
  std::vector<int> omp_cpu_ids;
  omp_cpu_ids.reserve(omp_cpu_mask->size);

  constexpr int group_size = 8 * sizeof(*omp_cpu_mask->maskp);

  for (int offset = 0; offset < omp_cpu_mask->size; offset += group_size) {
    unsigned long group_mask = omp_cpu_mask->maskp[offset / group_size];
    int i = 0;
    while (group_mask) {
      if (group_mask & 1) {
        omp_cpu_ids.emplace_back(offset + i);
      }
      ++i;
      group_mask >>= 1;
    }
  }

  // Memory node binding
  if (numa_available() != -1) {
    int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front());
    bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str());
    bitmask* src_mask = numa_get_membind();

    int pid = getpid();

    // move all existing pages to the specified numa node.
    *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp);
    int page_num = numa_migrate_pages(pid, src_mask, mask);
    if (page_num == -1) {
      TORCH_CHECK(false,
                  "numa_migrate_pages failed. errno: " + std::to_string(errno));
    }

    // restrict memory allocation node.
    numa_set_membind(mask);
    numa_set_strict(1);
  }

  // OMP threads binding
  omp_set_num_threads((int)omp_cpu_ids.size());
  torch::set_num_threads((int)omp_cpu_ids.size());
  TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads());
  TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads());
#pragma omp parallel for schedule(static, 1)
  for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
    cpu_set_t* mask = CPU_ALLOC(omp_cpu_mask->size);
    size_t size = CPU_ALLOC_SIZE(omp_cpu_mask->size);
    CPU_ZERO_S(size, mask);
    CPU_SET_S(omp_cpu_ids[i], size, mask);
    sched_setaffinity(0, sizeof(cpu_set_t), mask);
    CPU_FREE(mask);
  }

  numa_free_nodemask(omp_cpu_mask);
}
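
The core of the function above walks the libnuma CPU mask word by word, collecting the index of every set bit. A small Python mirror of that loop may make the logic easier to follow (illustrative only, not part of the PR):

# Python mirror of the bitmask walk in init_cpu_threads_env: the CPU mask is
# an array of machine words, and each set bit marks an allowed CPU id.
def extract_cpu_ids(mask_words, bits_per_word=64):
    cpu_ids = []
    for word_idx, word in enumerate(mask_words):
        offset = word_idx * bits_per_word
        i = 0
        while word:  # stop once all set bits in this word are consumed
            if word & 1:
                cpu_ids.append(offset + i)
            i += 1
            word >>= 1
    return cpu_ids

# Bits 0-3 set in word 0 and bit 0 set in word 1 -> CPUs 0, 1, 2, 3, 64.
assert extract_cpu_ids([0b1111, 0b1]) == [0, 1, 2, 3, 64]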
55 changes: 47 additions & 8 deletions docs/source/getting_started/cpu-installation.rst
@@ -10,6 +10,7 @@ Table of contents:
 #. :ref:`Requirements <cpu_backend_requirements>`
 #. :ref:`Quick start using Dockerfile <cpu_backend_quick_start_dockerfile>`
 #. :ref:`Build from source <build_cpu_backend_from_source>`
+#. :ref:`Related runtime environment variables <env_intro>`
 #. :ref:`Intel Extension for PyTorch <ipex_guidance>`
 #. :ref:`Performance tips <cpu_backend_performance_tips>`
 
@@ -47,7 +48,7 @@ Build from source
 
 .. code-block:: console
 
    $ sudo apt-get update -y
-   $ sudo apt-get install -y gcc-12 g++-12
+   $ sudo apt-get install -y gcc-12 g++-12 libnuma-dev
    $ sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
 
 - Second, install Python packages for vLLM CPU backend building:
 
@@ -71,22 +72,27 @@ Build from source
 
 - If you want to force enable AVX512_BF16 for the cross-compilation, please set the environment variable VLLM_CPU_AVX512BF16=1 before building.
 
+.. _env_intro:
+
+Related runtime environment variables
+-------------------------------------
+
+- ``VLLM_CPU_KVCACHE_SPACE``: specifies the KV cache size (e.g., ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB of space for the KV cache). A larger setting allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and the user's memory management pattern.
+
+- ``VLLM_CPU_OMP_THREADS_BIND``: specifies the CPU cores dedicated to the OpenMP threads. For example, ``VLLM_CPU_OMP_THREADS_BIND=0-31`` means 32 OpenMP threads bound to CPU cores 0-31. ``VLLM_CPU_OMP_THREADS_BIND=0-31|32-63`` means two tensor parallel processes: the 32 OpenMP threads of rank 0 are bound to CPU cores 0-31, and the OpenMP threads of rank 1 are bound to CPU cores 32-63.
+
 .. _ipex_guidance:
 
 Intel Extension for PyTorch
 ---------------------------
 
 - `Intel Extension for PyTorch (IPEX) <https://github.com/intel/intel-extension-for-pytorch>`_ extends PyTorch with up-to-date feature optimizations for an extra performance boost on Intel hardware.
 
 - IPEX ``2.3.0`` and later can be enabled in the CPU backend by default if it is installed.
 
 .. _cpu_backend_performance_tips:
 
 Performance tips
 -----------------
 
-- vLLM CPU backend uses environment variable ``VLLM_CPU_KVCACHE_SPACE`` to specify the KV Cache size (e.g, ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users.
 
 - We highly recommend using TCMalloc for high performance memory allocation and better cache locality. For example, on Ubuntu 22.04, you can run:

.. code-block:: console
@@ -96,11 +102,44 @@
$ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD # prepend the library to LD_PRELOAD
$ python examples/offline_inference.py # run vLLM

-- vLLM CPU backend uses OpenMP for thread-parallel computation. If you want the best performance on CPU, it will be very critical to isolate CPU cores for OpenMP threads with other thread pools (like web-service event-loop), to avoid CPU oversubscription.
-
-- If using vLLM CPU backend on a bare-metal machine, it is recommended to disable the hyper-threading.
+- When using online serving, it is recommended to reserve 1-2 CPU cores for the serving framework to avoid CPU oversubscription. For example, on a platform with 32 physical CPU cores, reserving CPU 30 and 31 for the framework and using CPU 0-29 for OpenMP:
+
+  .. code-block:: console
+
+     $ export VLLM_CPU_KVCACHE_SPACE=40
+     $ export VLLM_CPU_OMP_THREADS_BIND=0-29
+     $ vllm serve facebook/opt-125m

[Review thread on this hunk]

Collaborator: QQ: Will the API server automatically use CPU 30 and 31?

Contributor (author): Yes, CPU 30 and 31 are reserved for non-OpenMP threads (e.g., Python threads, the asyncio event loop, ...) and the OS scheduler places work on them automatically.

Collaborator: I really appreciate this simplification. However, can we further set this env variable internally in vLLM so that users don't have to care about it? Just wondering, because it's still not super easy to me. For example, users may have the following questions:

  1. Can I use this arg to control the number of CPU cores I'd like to allocate to vLLM?
  2. How does this arg relate to vLLM performance? Will allocating more CPUs improve the performance?

Contributor (author): Yes, a fully automatic setting would be the best solution, but it requires detecting the topology of CPU cores and memory nodes; we want to achieve such out-of-the-box usage as well. ``VLLM_CPU_OMP_THREADS_BIND`` controls the OpenMP thread behavior of model inference, including the thread count, thread affinity (pinning an inference thread to a fixed CPU core), and the memory allocation policy (allocating memory only from the closest memory node). We have added two performance tips about this arg for platforms with hyper-threading or multi-socket configurations. For platforms without hyper-threading or multiple sockets, allocating more CPUs to model inference should theoretically improve performance.

+- If using vLLM CPU backend on a machine with hyper-threading, it is recommended to bind only one OpenMP thread to each physical CPU core using ``VLLM_CPU_OMP_THREADS_BIND``. On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores:
+
+  .. code-block:: console
+
+     $ lscpu -e # check the mapping between logical CPU cores and physical CPU cores
+
+     # The "CPU" column shows the logical CPU core IDs, and the "CORE" column shows the physical core IDs. On this platform, two logical cores share one physical core.
+     CPU NODE SOCKET CORE L1d:L1i:L2:L3 ONLINE MAXMHZ MINMHZ MHZ
+     0 0 0 0 0:0:0:0 yes 2401.0000 800.0000 800.000
+     1 0 0 1 1:1:1:0 yes 2401.0000 800.0000 800.000
+     2 0 0 2 2:2:2:0 yes 2401.0000 800.0000 800.000
+     3 0 0 3 3:3:3:0 yes 2401.0000 800.0000 800.000
+     4 0 0 4 4:4:4:0 yes 2401.0000 800.0000 800.000
+     5 0 0 5 5:5:5:0 yes 2401.0000 800.0000 800.000
+     6 0 0 6 6:6:6:0 yes 2401.0000 800.0000 800.000
+     7 0 0 7 7:7:7:0 yes 2401.0000 800.0000 800.000
+     8 0 0 0 0:0:0:0 yes 2401.0000 800.0000 800.000
+     9 0 0 1 1:1:1:0 yes 2401.0000 800.0000 800.000
+     10 0 0 2 2:2:2:0 yes 2401.0000 800.0000 800.000
+     11 0 0 3 3:3:3:0 yes 2401.0000 800.0000 800.000
+     12 0 0 4 4:4:4:0 yes 2401.0000 800.0000 800.000
+     13 0 0 5 5:5:5:0 yes 2401.0000 800.0000 800.000
+     14 0 0 6 6:6:6:0 yes 2401.0000 800.0000 800.000
+     15 0 0 7 7:7:7:0 yes 2401.0000 800.0000 800.000
+
+     # On this platform, it is recommended to bind OpenMP threads only to logical CPU cores 0-7 or 8-15
+     $ export VLLM_CPU_OMP_THREADS_BIND=0-7
+     $ python examples/offline_inference.py
+
-- If using vLLM CPU backend on a multi-socket machine with NUMA, be aware to set CPU cores and memory nodes, to avoid remote memory node access. ``numactl`` is a useful tool for CPU core and memory binding on NUMA platforms. Besides, the ``--cpuset-cpus`` and ``--cpuset-mems`` arguments of ``docker run`` are also useful.
+- If using vLLM CPU backend on a multi-socket machine with NUMA, be sure to set the CPU cores using ``VLLM_CPU_OMP_THREADS_BIND`` to avoid cross-NUMA-node memory access.

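To make the ``VLLM_CPU_OMP_THREADS_BIND`` format concrete, here is a hedged sketch of how the '|'-separated value maps tensor parallel ranks to core lists. The helper name is hypothetical, not vLLM's internal parser:

# Hypothetical parser for the VLLM_CPU_OMP_THREADS_BIND format: ranks are
# separated by '|', and each rank's spec is a comma list of ids or ranges.
def parse_omp_threads_bind(value: str) -> list[list[int]]:
    per_rank = []
    for rank_spec in value.split("|"):
        cores = []
        for part in rank_spec.split(","):
            if "-" in part:
                lo, hi = part.split("-")
                cores.extend(range(int(lo), int(hi) + 1))
            else:
                cores.append(int(part))
        per_rank.append(cores)
    return per_rank

# "0-31|32-63" -> rank 0 binds cores 0..31, rank 1 binds cores 32..63.
assert parse_omp_threads_bind("0-1|32,33") == [[0, 1], [32, 33]]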

4 changes: 2 additions & 2 deletions requirements-cpu.txt
@@ -2,6 +2,6 @@
 -r requirements-common.txt
 
 # Dependencies for x86_64 CPUs
-torch == 2.3.1+cpu; platform_machine != "ppc64le"
-torchvision == 0.18.1+cpu; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
+torch == 2.4.0; platform_machine != "ppc64le"
+torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
 triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
[Review thread on this hunk]

Contributor:
Hi @bigPYJ1151 and @WoosukKwon, this change missed the +cpu suffix for the torch version, which leads to a build failure of the CPU target. Please take a look at #6931. Thanks.

Contributor:

This also does something more subtle: the required torch version in requirements-build.txt and pyproject.toml is still 2.3.1, so building with pip install . (i.e., PEP 517 style builds) results in a broken build:

WARNING 08-01 17:51:34 _custom_ops.py:14] Failed to import from vllm._C with ImportError('/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/_C.abi3.so: undefined symbol: torch::jit::parseSchema(std::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)')
...

Querying the engine then results in an error: AttributeError: '_OpNamespace' '_C_cache_ops' object has no attribute 'reshape_and_cache'.

Full traceback follows:

(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 _custom_ops.py:39] Error in calling custom op reshape_and_cache: '_OpNamespace' '_C_cache_ops' object has no attribute 'reshape_and_cache'
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 _custom_ops.py:39] Possibly you have built or installed an obsolete version of vllm.
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 _custom_ops.py:39] Please try a clean build and install of vllm,or remove old built files such as vllm/*cpython*.so and build/ .
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226] Exception in worker VllmWorkerProcess while processing method execute_model: '_OpNamespace' '_C_cache_ops' object has no attribute 'reshape_and_cache', Traceback (most recent call last):
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/executor/multiproc_worker_utils.py", line 223, in _run_worker_process
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     output = executor(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]              ^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/worker/worker_base.py", line 273, in execute_model
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     output = self.model_runner.execute_model(
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return func(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/worker/cpu_model_runner.py", line 374, in execute_model
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     hidden_states = model_executable(**execute_model_kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/model_executor/models/opt.py", line 322, in forward
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     hidden_states = self.model(input_ids, positions, kv_caches,
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/model_executor/models/opt.py", line 291, in forward
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return self.decoder(input_ids,
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/model_executor/models/opt.py", line 260, in forward
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/model_executor/models/opt.py", line 162, in forward
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     hidden_states = self.self_attn(hidden_states=hidden_states,
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/model_executor/models/opt.py", line 105, in forward
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/attention/layer.py", line 97, in forward
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return self.impl.forward(query,
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/attention/backends/torch_sdpa.py", line 177, in forward
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     PagedAttention.write_to_paged_cache(key, value, key_cache,
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/attention/ops/paged_attn.py", line 75, in write_to_paged_cache
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     ops.reshape_and_cache(
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/_custom_ops.py", line 40, in wrapper
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     raise e
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/_custom_ops.py", line 31, in wrapper
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return fn(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/_custom_ops.py", line 425, in reshape_and_cache
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/_ops.py", line 1170, in __getattr__
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     raise AttributeError(
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226] AttributeError: '_OpNamespace' '_C_cache_ops' object has no attribute 'reshape_and_cache'
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]
ERROR 08-01 17:51:47 async_llm_engine.py:56] Engine background task failed
ERROR 08-01 17:51:47 async_llm_engine.py:56] Traceback (most recent call last):
ERROR 08-01 17:51:47 async_llm_engine.py:56]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 46, in _log_task_completion
ERROR 08-01 17:51:47 async_llm_engine.py:56]     return_value = task.result()
ERROR 08-01 17:51:47 async_llm_engine.py:56]                    ^^^^^^^^^^^^^
ERROR 08-01 17:51:47 async_llm_engine.py:56]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 641, in run_engine_loop
ERROR 08-01 17:51:47 async_llm_engine.py:56]     result = task.result()
ERROR 08-01 17:51:47 async_llm_engine.py:56]              ^^^^^^^^^^^^^
ERROR 08-01 17:51:47 async_llm_engine.py:56]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 584, in engine_step
ERROR 08-01 17:51:47 async_llm_engine.py:56]     request_outputs = await self.engine.step_async(virtual_engine)
ERROR 08-01 17:51:47 async_llm_engine.py:56]                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-01 17:51:47 async_llm_engine.py:56]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 253, in step_async
ERROR 08-01 17:51:47 async_llm_engine.py:56]     output = await self.model_executor.execute_model_async(
ERROR 08-01 17:51:47 async_llm_engine.py:56]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-01 17:51:47 async_llm_engine.py:56]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/executor/cpu_executor.py", line 305, in execute_model_async
ERROR 08-01 17:51:47 async_llm_engine.py:56]     output = await make_async(self.execute_model
ERROR 08-01 17:51:47 async_llm_engine.py:56]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-01 17:51:47 async_llm_engine.py:56]   File "/home/dtrifiro/.pyenv/versions/3.11.9/lib/python3.11/concurrent/futures/thread.py", line 58, in run
ERROR 08-01 17:51:47 async_llm_engine.py:56]     result = self.fn(*self.args, **self.kwargs)
ERROR 08-01 17:51:47 async_llm_engine.py:56]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-01 17:51:47 async_llm_engine.py:56]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/executor/cpu_executor.py", line 223, in execute_model
ERROR 08-01 17:51:47 async_llm_engine.py:56]     output = self.driver_method_invoker(self.driver_worker,
ERROR 08-01 17:51:47 async_llm_engine.py:56]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-01 17:51:47 async_llm_engine.py:56]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/executor/cpu_executor.py", line 362, in _async_driver_method_invoker
ERROR 08-01 17:51:47 async_llm_engine.py:56]     return driver.execute_method(method, *args, **kwargs).get()
ERROR 08-01 17:51:47 async_llm_engine.py:56]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-01 17:51:47 async_llm_engine.py:56]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/executor/multiproc_worker_utils.py", line 58, in get
ERROR 08-01 17:51:47 async_llm_engine.py:56]     raise self.result.exception
ERROR 08-01 17:51:47 async_llm_engine.py:56] AttributeError: '_OpNamespace' '_C_cache_ops' object has no attribute 'reshape_and_cache'
Exception in callback _log_task_completion(error_callback=<bound method ...>)() at /home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py:36
handle: <Handle _log_task_completion(error_callback=<bound method ...>)() at /home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py:36>

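A quick way to detect the broken-build failure mode reported above is to verify, right after import, that the native extension actually registered the cache ops. This is an illustrative sketch, not part of the PR; importing vllm._C to load the compiled extension is an assumption:

# Sanity check for the AttributeError above: confirm the compiled extension
# registered the cache ops after import.
import torch
import vllm._C  # noqa: F401  (assumed to load the compiled extension)

if not hasattr(torch.ops._C_cache_ops, "reshape_and_cache"):
    raise RuntimeError(
        "vllm._C imported, but _C_cache_ops.reshape_and_cache is missing; "
        "this usually indicates an obsolete or mismatched native build."
    )
print("native cache ops registered OK")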
3 changes: 3 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -296,6 +296,9 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         pynccl_comm = self.pynccl_comm
         if (pynccl_comm is not None and not pynccl_comm.disabled):
             pynccl_comm.all_reduce(input_)
+        elif input_.is_cpu:
+            import intel_extension_for_pytorch as ipex
+            ipex.distributed.all_reduce(input_, group=self.device_group)
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
         return input_
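
The dispatch added above routes CPU tensors through IPEX's optimized all-reduce and leaves other devices on the stock torch.distributed path. A standalone sketch of the same branching, assuming torch.distributed has already been initialized for the CPU backend:

# Standalone mirror of the CPU-aware all-reduce dispatch added in this hunk.
import torch
import torch.distributed as dist

def cpu_aware_all_reduce(t: torch.Tensor, group=None) -> torch.Tensor:
    if t.is_cpu:
        # IPEX provides an optimized CPU all-reduce (the call used in the diff).
        import intel_extension_for_pytorch as ipex
        ipex.distributed.all_reduce(t, group=group)
    else:
        dist.all_reduce(t, group=group)
    return t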
2 changes: 0 additions & 2 deletions vllm/engine/async_llm_engine.py
@@ -410,8 +410,6 @@ def _get_executor_cls(
             from vllm.executor.tpu_executor import TPUExecutorAsync
             executor_class = TPUExecutorAsync
         elif engine_config.device_config.device_type == "cpu":
-            assert distributed_executor_backend is None, (
-                "Distributed execution is not supported with the CPU backend.")
             from vllm.executor.cpu_executor import CPUExecutorAsync
             executor_class = CPUExecutorAsync
         elif engine_config.device_config.device_type == "openvino":
8 changes: 7 additions & 1 deletion vllm/envs.py
@@ -29,6 +29,7 @@
 VLLM_TRACE_FUNCTION: int = 0
 VLLM_ATTENTION_BACKEND: Optional[str] = None
 VLLM_CPU_KVCACHE_SPACE: int = 0
+VLLM_CPU_OMP_THREADS_BIND: str = ""
 VLLM_OPENVINO_KVCACHE_SPACE: int = 0
 VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None
 VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False
@@ -241,11 +242,16 @@ def get_default_config_root():
     "VLLM_ATTENTION_BACKEND":
     lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
 
-    # CPU key-value cache space
+    # (CPU backend only) CPU key-value cache space.
+    # default is 4GB
     "VLLM_CPU_KVCACHE_SPACE":
     lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")),
 
+    # (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31",
+    # "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'.
+    "VLLM_CPU_OMP_THREADS_BIND":
+    lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"),
+
     # OpenVINO key-value cache space
     # default is 4GB
     "VLLM_OPENVINO_KVCACHE_SPACE":