diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index a7678aae54644..21deec2bba973 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -3,26 +3,38 @@ set -ex # Try building the docker image -docker build -t cpu-test -f Dockerfile.cpu . -docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu . +numactl -C 48-95 -N 1 docker build -t cpu-test -f Dockerfile.cpu . +numactl -C 48-95 -N 1 docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu . # Setup cleanup remove_docker_container() { docker rm -f cpu-test cpu-test-avx2 || true; } trap remove_docker_container EXIT remove_docker_container -# Run the image +# Run the image, setting --shm-size=4g for tensor parallel. docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \ - --cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test + --cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \ - --cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test-avx2 cpu-test-avx2 + --cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-avx2 cpu-test-avx2 # offline inference -docker exec cpu-test bash -c "python3 examples/offline_inference.py" docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" # Run basic model test -docker exec cpu-test bash -c "cd tests; +docker exec cpu-test bash -c " pip install pytest Pillow protobuf - cd ../ pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported + +# online inference +docker exec cpu-test bash -c " + export VLLM_CPU_KVCACHE_SPACE=10 + export VLLM_CPU_OMP_THREADS_BIND=48-92 + python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m & + timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 + python3 benchmarks/benchmark_serving.py \ + --backend vllm \ + --dataset-name random \ + --model facebook/opt-125m \ + --num-prompts 20 \ + --endpoint /v1/completions \ + --tokenizer facebook/opt-125m" diff --git a/Dockerfile.cpu b/Dockerfile.cpu index f95d748f1e4be..c473ba431e680 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -2,8 +2,8 @@ FROM ubuntu:22.04 AS cpu-test-1 -RUN apt-get update -y \ - && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 \ +RUN apt-get update -y \ + && apt-get install -y curl git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \ && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 # https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html @@ -13,8 +13,9 @@ RUN pip install intel-openmp ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so:$LD_PRELOAD" +RUN echo 'ulimit -c 0' >> ~/.bashrc -RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl +RUN pip install 
https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl RUN pip install --upgrade pip \ && pip install wheel packaging ninja "setuptools>=49.4.0" numpy @@ -25,7 +26,7 @@ COPY ./ /workspace/vllm WORKDIR /workspace/vllm -RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu +RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/test/cpu # Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ... ARG VLLM_CPU_DISABLE_AVX512 diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 690559ee265e9..118f9b28e0ae3 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -83,6 +83,8 @@ endif() message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") +list(APPEND LIBS "numa") + # # Define extension targets @@ -95,6 +97,7 @@ set(VLLM_EXT_SRC "csrc/cpu/activation.cpp" "csrc/cpu/attention.cpp" "csrc/cpu/cache.cpp" + "csrc/cpu/utils.cpp" "csrc/cpu/layernorm.cpp" "csrc/cpu/pos_encoding.cpp" "csrc/cpu/torch_bindings.cpp") @@ -104,6 +107,7 @@ define_gpu_extension_target( DESTINATION vllm LANGUAGE CXX SOURCES ${VLLM_EXT_SRC} + LIBRARIES ${LIBS} COMPILE_FLAGS ${CXX_COMPILE_FLAGS} USE_SABI 3 WITH_SOABI diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 5be0e9810b5b9..7d549e271a30d 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -4,6 +4,8 @@ #include +void init_cpu_threads_env(const std::string& cpu_ids); + TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -107,4 +109,9 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); } +TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) { + // CPU utils + utils.def("init_cpu_threads_env(str cpu_ids) -> ()", &init_cpu_threads_env); +} + REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/cpu/utils.cpp b/csrc/cpu/utils.cpp new file mode 100644 index 0000000000000..5782580baa861 --- /dev/null +++ b/csrc/cpu/utils.cpp @@ -0,0 +1,65 @@ +#include +#include +#include +#include + +#include "cpu_types.hpp" + +void init_cpu_threads_env(const std::string& cpu_ids) { + bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str()); + TORCH_CHECK(omp_cpu_mask->size > 0); + std::vector omp_cpu_ids; + omp_cpu_ids.reserve(omp_cpu_mask->size); + + constexpr int group_size = 8 * sizeof(*omp_cpu_mask->maskp); + + for (int offset = 0; offset < omp_cpu_mask->size; offset += group_size) { + unsigned long group_mask = omp_cpu_mask->maskp[offset / group_size]; + int i = 0; + while (group_mask) { + if (group_mask & 1) { + omp_cpu_ids.emplace_back(offset + i); + } + ++i; + group_mask >>= 1; + } + } + + // Memory node binding + if (numa_available() != -1) { + int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front()); + bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str()); + bitmask* src_mask = numa_get_membind(); + + int pid = getpid(); + + // move all existing pages to the specified numa node. + *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp); + int page_num = numa_migrate_pages(pid, src_mask, mask); + if (page_num == -1) { + TORCH_CHECK(false, + "numa_migrate_pages failed. errno: " + std::to_string(errno)); + } + + // restrict memory allocation node. 
+    numa_set_membind(mask);
+    numa_set_strict(1);
+  }
+
+  // OMP threads binding
+  omp_set_num_threads((int)omp_cpu_ids.size());
+  torch::set_num_threads((int)omp_cpu_ids.size());
+  TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads());
+  TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads());
+#pragma omp parallel for schedule(static, 1)
+  for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
+    cpu_set_t* mask = CPU_ALLOC(omp_cpu_mask->size);
+    size_t size = CPU_ALLOC_SIZE(omp_cpu_mask->size);
+    CPU_ZERO_S(size, mask);
+    CPU_SET_S(omp_cpu_ids[i], size, mask);
+    sched_setaffinity(0, sizeof(cpu_set_t), mask);
+    CPU_FREE(mask);
+  }
+
+  numa_free_nodemask(omp_cpu_mask);
+}
diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst
index 1c97515dbecd9..7fc469e06844f 100644
--- a/docs/source/getting_started/cpu-installation.rst
+++ b/docs/source/getting_started/cpu-installation.rst
@@ -10,6 +10,7 @@ Table of contents:
 #. :ref:`Requirements `
 #. :ref:`Quick start using Dockerfile `
 #. :ref:`Build from source `
+#. :ref:`Related runtime environment variables <env_intro>`
 #. :ref:`Intel Extension for PyTorch `
 #. :ref:`Performance tips `
@@ -47,7 +48,7 @@ Build from source
 .. code-block:: console
 
     $ sudo apt-get update -y
-    $ sudo apt-get install -y gcc-12 g++-12
+    $ sudo apt-get install -y gcc-12 g++-12 libnuma-dev
     $ sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
 
 - Second, install Python packages for vLLM CPU backend building:
@@ -71,6 +72,15 @@ Build from source
 - If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable VLLM_CPU_AVX512BF16=1 before the building.
 
+.. _env_intro:
+
+Related runtime environment variables
+-------------------------------------
+
+- ``VLLM_CPU_KVCACHE_SPACE``: specify the KV cache size (e.g., ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB of space for the KV cache); a larger setting allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and the memory management pattern of users.
+
+- ``VLLM_CPU_OMP_THREADS_BIND``: specify the CPU cores dedicated to the OpenMP threads. For example, ``VLLM_CPU_OMP_THREADS_BIND=0-31`` means there will be 32 OpenMP threads bound to CPU cores 0-31. ``VLLM_CPU_OMP_THREADS_BIND=0-31|32-63`` means there will be 2 tensor-parallel processes: the 32 OpenMP threads of rank 0 are bound to CPU cores 0-31, and the OpenMP threads of rank 1 are bound to CPU cores 32-63.
+
 .. _ipex_guidance:
 
 Intel Extension for PyTorch
@@ -78,15 +88,11 @@ Intel Extension for PyTorch
 
 - `Intel Extension for PyTorch (IPEX) `_ extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware.
 
-- IPEX after the ``2.3.0`` can be enabled in the CPU backend by default if it is installed.
-
 .. _cpu_backend_performance_tips:
 
 Performance tips
 -----------------
 
-- vLLM CPU backend uses environment variable ``VLLM_CPU_KVCACHE_SPACE`` to specify the KV Cache size (e.g, ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users.
-
 - We highly recommend to use TCMalloc for high performance memory allocation and better cache locality. For example, on Ubuntu 22.4, you can run:
 
 .. code-block:: console
@@ -96,11 +102,44 @@ Performance tips
 
     $ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD # prepend the library to LD_PRELOAD
     $ python examples/offline_inference.py # run vLLM
 
-- vLLM CPU backend uses OpenMP for thread-parallel computation. If you want the best performance on CPU, it will be very critical to isolate CPU cores for OpenMP threads with other thread pools (like web-service event-loop), to avoid CPU oversubscription.
+- When using online serving, it is recommended to reserve 1-2 CPU cores for the serving framework to avoid CPU oversubscription. For example, on a platform with 32 physical CPU cores, reserve CPU cores 30 and 31 for the framework and use CPU cores 0-29 for OpenMP:
 
-- If using vLLM CPU backend on a bare-metal machine, it is recommended to disable the hyper-threading.
+.. code-block:: console
+
+    $ export VLLM_CPU_KVCACHE_SPACE=40
+    $ export VLLM_CPU_OMP_THREADS_BIND=0-29
+    $ vllm serve facebook/opt-125m
+
+- If using vLLM CPU backend on a machine with hyper-threading, it is recommended to bind only one OpenMP thread to each physical CPU core using ``VLLM_CPU_OMP_THREADS_BIND``. On a hyper-threading enabled platform with 16 logical CPU cores / 8 physical CPU cores:
+
+.. code-block:: console
 
-- If using vLLM CPU backend on a multi-socket machine with NUMA, be aware to set CPU cores and memory nodes, to avoid the remote memory node access. ``numactl`` is an useful tool for CPU core and memory binding on NUMA platform. Besides, ``--cpuset-cpus`` and ``--cpuset-mems`` arguments of ``docker run`` are also useful.
+    $ lscpu -e # check the mapping between logical CPU cores and physical CPU cores
+
+    # The "CPU" column shows the logical CPU core IDs, and the "CORE" column shows the physical core IDs. On this platform, two logical cores share one physical core.
+    CPU NODE SOCKET CORE L1d:L1i:L2:L3 ONLINE    MAXMHZ   MINMHZ      MHZ
+    0    0      0    0 0:0:0:0          yes 2401.0000 800.0000  800.000
+    1    0      0    1 1:1:1:0          yes 2401.0000 800.0000  800.000
+    2    0      0    2 2:2:2:0          yes 2401.0000 800.0000  800.000
+    3    0      0    3 3:3:3:0          yes 2401.0000 800.0000  800.000
+    4    0      0    4 4:4:4:0          yes 2401.0000 800.0000  800.000
+    5    0      0    5 5:5:5:0          yes 2401.0000 800.0000  800.000
+    6    0      0    6 6:6:6:0          yes 2401.0000 800.0000  800.000
+    7    0      0    7 7:7:7:0          yes 2401.0000 800.0000  800.000
+    8    0      0    0 0:0:0:0          yes 2401.0000 800.0000  800.000
+    9    0      0    1 1:1:1:0          yes 2401.0000 800.0000  800.000
+    10   0      0    2 2:2:2:0          yes 2401.0000 800.0000  800.000
+    11   0      0    3 3:3:3:0          yes 2401.0000 800.0000  800.000
+    12   0      0    4 4:4:4:0          yes 2401.0000 800.0000  800.000
+    13   0      0    5 5:5:5:0          yes 2401.0000 800.0000  800.000
+    14   0      0    6 6:6:6:0          yes 2401.0000 800.0000  800.000
+    15   0      0    7 7:7:7:0          yes 2401.0000 800.0000  800.000
+
+    # On this platform, it is recommended to bind OpenMP threads only to logical CPU cores 0-7 or 8-15
+    $ export VLLM_CPU_OMP_THREADS_BIND=0-7
+    $ python examples/offline_inference.py
+
+- If using vLLM CPU backend on a multi-socket machine with NUMA, be sure to set the CPU cores using ``VLLM_CPU_OMP_THREADS_BIND`` to avoid cross-NUMA-node memory access.
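Note: the ``VLLM_CPU_OMP_THREADS_BIND`` format documented above can also be exercised from the offline API once this patch is in place. The snippet below is a minimal sketch rather than part of this patch; the model name, KV-cache size, and core ranges are illustrative and assume a host with at least 64 CPU cores.

.. code-block:: python

    # Minimal sketch: 2-way tensor-parallel offline inference on the CPU backend.
    # Adapt the core ranges to the `lscpu -e` layout of the target machine.
    import os

    os.environ["VLLM_CPU_KVCACHE_SPACE"] = "40"             # 40 GB of KV cache, as in the docs example
    os.environ["VLLM_CPU_OMP_THREADS_BIND"] = "0-31|32-63"  # rank 0 -> cores 0-31, rank 1 -> cores 32-63

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
    for out in outputs:
        print(out.outputs[0].text)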
diff --git a/requirements-cpu.txt b/requirements-cpu.txt index 754070df21c0a..a8ce104d83290 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -2,6 +2,6 @@ -r requirements-common.txt # Dependencies for x86_64 CPUs -torch == 2.3.1+cpu; platform_machine != "ppc64le" -torchvision == 0.18.1+cpu; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch +torch == 2.4.0; platform_machine != "ppc64le" +torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error. diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index e9c6fc3a255e4..58cae46d9af27 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -296,6 +296,9 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: pynccl_comm = self.pynccl_comm if (pynccl_comm is not None and not pynccl_comm.disabled): pynccl_comm.all_reduce(input_) + elif input_.is_cpu: + import intel_extension_for_pytorch as ipex + ipex.distributed.all_reduce(input_, group=self.device_group) else: torch.distributed.all_reduce(input_, group=self.device_group) return input_ diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 16b7bc64a2849..93cc319f11c42 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -410,8 +410,6 @@ def _get_executor_cls( from vllm.executor.tpu_executor import TPUExecutorAsync executor_class = TPUExecutorAsync elif engine_config.device_config.device_type == "cpu": - assert distributed_executor_backend is None, ( - "Distributed execution is not supported with the CPU backend.") from vllm.executor.cpu_executor import CPUExecutorAsync executor_class = CPUExecutorAsync elif engine_config.device_config.device_type == "openvino": diff --git a/vllm/envs.py b/vllm/envs.py index 595992e51db87..f06b6d66ea6f4 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -29,6 +29,7 @@ VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_CPU_KVCACHE_SPACE: int = 0 + VLLM_CPU_OMP_THREADS_BIND: str = "" VLLM_OPENVINO_KVCACHE_SPACE: int = 0 VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False @@ -241,11 +242,16 @@ def get_default_config_root(): "VLLM_ATTENTION_BACKEND": lambda: os.getenv("VLLM_ATTENTION_BACKEND", None), - # CPU key-value cache space + # (CPU backend only) CPU key-value cache space. # default is 4GB "VLLM_CPU_KVCACHE_SPACE": lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")), + # (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31", + # "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'. 
+ "VLLM_CPU_OMP_THREADS_BIND": + lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"), + # OpenVINO key-value cache space # default is 4GB "VLLM_OPENVINO_KVCACHE_SPACE": diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 23e429dac7232..3229e5ad20afa 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -1,16 +1,21 @@ -from typing import List, Set, Tuple +import os +from functools import partial +from typing import Any, Awaitable, List, Optional, Set, Tuple, Union import torch import vllm.envs as envs from vllm.config import CacheConfig, ModelConfig, SchedulerConfig from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, + ResultHandler, WorkerMonitor) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, SamplerOutput -from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, - make_async) +from vllm.utils import (get_distributed_init_method, get_open_port, + get_vllm_instance_id, make_async) +from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -22,46 +27,173 @@ class CPUExecutor(ExecutorBase): def _init_executor(self) -> None: assert self.device_config.device_type == "cpu" assert self.lora_config is None, "cpu backend doesn't support LoRA" + + # + # Environment variables for CPU executor + # + + # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers + os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() + + # Disable torch async compiling which won't work with daemonic processes + os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" + + # Intel OpenMP setting + ld_prealod_str = os.getenv("LD_PRELOAD", "") + if "libiomp5.so" in ld_prealod_str: + # The time(milliseconds) that a thread should wait after + # completing the execution of a parallel region, before sleeping. + os.environ['KMP_BLOCKTIME'] = "1" + # Prevents the CPU to run into low performance state + os.environ['KMP_TPAUSE'] = "0" + # Provides fine granularity parallelism + os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist" + os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist" + os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist" + + # To hint IPEX uses shared memory based AllReduce + os.environ["LOCAL_WORLD_SIZE"] = str( + self.parallel_config.tensor_parallel_size) + self.model_config = _verify_and_get_model_config(self.model_config) self.cache_config = _verify_and_get_cache_config(self.cache_config) self.scheduler_config = _verify_and_get_scheduler_config( self.scheduler_config) - # Instantiate the worker and load the model to CPU. - self._init_worker() - - def _init_worker(self): - from vllm.worker.cpu_worker import CPUWorker + # Multiprocessing-based executor does not support multi-node setting. + # Since it only works for single node, we can use the loopback address + # 127.0.0.1 for communication. 
+ ip = "127.0.0.1" + port = get_open_port() + self.distributed_init_method = get_distributed_init_method(ip, port) + + is_async = isinstance(self, CPUExecutorAsync) + + world_size = self.parallel_config.tensor_parallel_size + result_handler = ResultHandler() + self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None + self.workers = [] + + if is_async: + self.workers = [ + ProcessWorkerWrapper( + result_handler, + partial( + self._create_worker, + rank=rank, + local_rank=rank, + )) for rank in range(0, world_size) + ] + self.driver_worker = self.workers[0] + self.workers = self.workers[1:] + self.driver_method_invoker = _async_driver_method_invoker + else: + self.driver_worker = self._create_worker() + self.driver_method_invoker = _driver_method_invoker + + if world_size != 1: + self.workers = [ + ProcessWorkerWrapper( + result_handler, + partial( + self._create_worker, + rank=rank, + local_rank=rank, + )) for rank in range(1, world_size) + ] + + if world_size != 1 or is_async: + if is_async: + async_worker_list = self.workers + [self.driver_worker] + else: + async_worker_list = self.workers + self.worker_monitor = WorkerMonitor(async_worker_list, + result_handler) + result_handler.start() + self.worker_monitor.start() + + self._run_workers("init_device") + self._run_workers("load_model") + + def _create_worker( + self, + local_rank: int = 0, + rank: int = 0, + ): + worker_module_name = "vllm.worker.cpu_worker" + worker_class_name = "CPUWorker" + + wrapper = WorkerWrapperBase( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + ) - assert self.parallel_config.world_size == 1, ( - "CPUExecutor only supports single CPU socket currently.") + assert self.distributed_init_method is not None - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - self.driver_worker = CPUWorker( + kwargs = dict( model_config=self.model_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, device_config=self.device_config, cache_config=self.cache_config, load_config=self.load_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, + local_rank=local_rank, + rank=rank, + distributed_init_method=self.distributed_init_method, lora_config=self.lora_config, multimodal_config=self.multimodal_config, kv_cache_dtype=self.cache_config.cache_dtype, prompt_adapter_config=self.prompt_adapter_config, - is_driver_worker=True, + is_driver_worker=rank == 0, ) - self.driver_worker.init_device() - self.driver_worker.load_model() + wrapper.init_worker(**kwargs) + + return wrapper.worker + + def _run_workers( + self, + method: str, + *args, + async_run_remote_workers_only: bool = False, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers. + + Args: + async_run_remote_workers_only: If True the method will be run only + in the remote workers, not the driver worker. It will also be + run asynchronously and return a list of futures rather than + blocking on the results. + """ + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + # Start the workers first. + worker_outputs = [ + worker.execute_method(method, *args, **kwargs) + for worker in self.workers + ] + + if async_run_remote_workers_only: + # Just return futures + return worker_outputs + + driver_worker_output = self.driver_method_invoker( + self.driver_worker, method, *args, **kwargs) + + # Get the results of the workers. 
+ return [driver_worker_output + ] + [output.get() for output in worker_outputs] def determine_num_available_blocks(self) -> Tuple[int, int]: """Determine the number of available KV blocks by invoking the underlying worker. """ - return self.driver_worker.determine_num_available_blocks() + return self.driver_method_invoker(self.driver_worker, + "determine_num_available_blocks") def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: @@ -74,43 +206,95 @@ def initialize_cache(self, num_gpu_blocks: int, # referred as `gpu block`. Because we want to reuse the existing block # management procedure. logger.info("# CPU blocks: %d", num_gpu_blocks) - self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + self._run_workers("initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) def execute_model( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - output = self.driver_worker.execute_model(execute_model_req) + if (self.parallel_config.tensor_parallel_size > 1 + and self.parallel_worker_tasks is None): + self.parallel_worker_tasks = self._run_workers( + "start_worker_execution_loop", + async_run_remote_workers_only=True, + ) + output = self.driver_method_invoker(self.driver_worker, + "execute_model", execute_model_req) return output + def stop_remote_worker_execution_loop(self) -> None: + if self.parallel_worker_tasks is None: + return + """ + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + self.driver_method_invoker(self.driver_worker, "execute_model", None) + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + self._wait_for_tasks_completion(parallel_worker_tasks) + def add_lora(self, lora_request: LoRARequest) -> bool: - return self.driver_worker.add_lora(lora_request) + return all(self._run_workers("add_lora", lora_request)) def remove_lora(self, lora_id: int) -> bool: - return self.driver_worker.remove_lora(lora_id) + return all(self._run_workers("remove_lora", lora_id)) def pin_lora(self, lora_id: int) -> bool: - return self.driver_worker.pin_lora(lora_id) + assert lora_id > 0, "lora_id must be greater than 0." + return all(self._run_workers( + "pin_lora", + lora_id=lora_id, + )) def list_loras(self) -> Set[int]: - return self.driver_worker.list_loras() + return self.driver_method_invoker(self.driver_worker, "list_loras") def add_prompt_adapter( self, prompt_adapter_request: PromptAdapterRequest) -> bool: - return self.driver_worker.add_prompt_adapter(prompt_adapter_request) + return all( + self._run_workers( + "add_prompt_adapter", + prompt_adapter_request, + )) def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: - return self.driver_worker.remove_prompt_adapter(prompt_adapter_id) + return all( + self._run_workers( + "remove_prompt_adapter", + prompt_adapter_id, + )) def list_prompt_adapters(self) -> Set[int]: - return self.driver_worker.list_prompt_adapters() + return self.driver_method_invoker(self.driver_worker, + "list_prompt_adapters") def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: - return self.driver_worker.pin_prompt_adapter(prompt_adapter_id) + return all(self._run_workers( + "pin_prompt_adapter", + prompt_adapter_id, + )) def check_health(self) -> None: - # CPUExecutor will always be healthy as long as - # it's running. 
- return + """Raises an error if engine is unhealthy.""" + if self.worker_monitor is not None and not self.worker_monitor.is_alive( + ): + raise RuntimeError("Worker processes are not running") + + def shutdown(self): + if (worker_monitor := getattr(self, "worker_monitor", + None)) is not None: + worker_monitor.close() + + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + for result in parallel_worker_tasks: + result.get() class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase): @@ -118,14 +302,12 @@ class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase): async def execute_model_async( self, execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: - output = await make_async(self.driver_worker.execute_model + output = await make_async(self.execute_model )(execute_model_req=execute_model_req, ) return output async def check_health_async(self) -> None: - # CPUExecutor will always be healthy as long as - # it's running. - return + self.check_health() def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: @@ -170,3 +352,11 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: f" {kv_cache_space}, expect a positive integer value.") return config + + +def _driver_method_invoker(driver, method: str, *args, **kwargs): + return getattr(driver, method)(*args, **kwargs) + + +def _async_driver_method_invoker(driver, method: str, *args, **kwargs): + return driver.execute_method(method, *args, **kwargs).get() diff --git a/vllm/utils.py b/vllm/utils.py index 876c3bf90b02c..90be09fc7b967 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -404,27 +404,6 @@ def update_environment_variables(envs: Dict[str, str]): os.environ[k] = v -def init_kmp_env(): - if not is_cpu(): - return - - ld_prealod_str = os.getenv("LD_PRELOAD", "") - if "libiomp5.so" not in ld_prealod_str: - return - - # The time(milliseconds) that a thread should wait after completing the - # execution of a parallel region, before sleeping. 
- os.environ['KMP_BLOCKTIME'] = "1" - # dump settings on start up - os.environ['KMP_SETTINGS'] = "1" - # Prevents the CPU to run into low performance state - os.environ['KMP_TPAUSE'] = "0" - # Provides fine granularity parallelism - os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist" - os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist" - os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist" - - def chunk_list(lst: List[T], chunk_size: int): """Yield successive chunk_size chunks from lst.""" for i in range(0, len(lst), chunk_size): diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 83f4ba69fb728..71763c08ec45f 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -42,6 +42,7 @@ class CPUModelInput(ModelRunnerInputBase): attn_metadata: Optional["AttentionMetadata"] = None sampling_metadata: Optional["SamplingMetadata"] = None multi_modal_kwargs: Optional[Mapping[str, BatchedTensors]] = None + virtual_engine: Optional[int] = None def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: @@ -204,8 +205,8 @@ def _prepare_prompt( attn_metadata = self.attn_backend.make_metadata( is_prompt=True, seq_lens=seq_lens, - seq_lens_tensor=None, - max_decode_seq_len=None, + seq_lens_tensor=torch.tensor([]), + max_decode_seq_len=0, num_prefills=len(seq_lens), num_prefill_tokens=num_prompt_tokens, num_decode_tokens=0, @@ -345,7 +346,7 @@ def prepare_model_input( multi_modal_kwargs=multi_modal_kwargs, ) - @torch.inference_mode() + @torch.no_grad() def execute_model( self, model_input: CPUModelInput, diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 3c22c73267b7f..735d48c908d61 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -4,6 +4,7 @@ import torch import torch.distributed +import vllm.envs as envs from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, @@ -13,7 +14,7 @@ from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, init_kmp_env +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.worker.cpu_model_runner import CPUModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, LoraNotSupportedWorkerBase, WorkerInput) @@ -152,13 +153,18 @@ def __init__( if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." - # try to initialize intel openmp optimized tunings - init_kmp_env() - if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules init_cached_hf_modules() + + # Setup OpenMP threads affinity. + omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND + if omp_cpuids == "all": + self.local_omp_cpuid = "all" + else: + self.local_omp_cpuid = omp_cpuids.split("|")[rank] + self.model_runner: CPUModelRunner = CPUModelRunner( model_config, parallel_config, @@ -177,6 +183,9 @@ def __init__( self.cpu_cache: List[List[torch.Tensor]] def init_device(self) -> None: + if self.local_omp_cpuid != "all": + torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) + self.init_distributed_environment() # Set random seed. set_random_seed(self.model_config.seed)
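Note on the worker-side change above: each rank simply takes its positional slice of the ``|``-separated ``VLLM_CPU_OMP_THREADS_BIND`` string before handing it to ``torch.ops._C_utils.init_cpu_threads_env``. The helper below is a hypothetical illustration of that mapping (mirroring ``omp_cpuids.split("|")[rank]``), not code from this patch.

.. code-block:: python

    # Illustration of the rank -> core-group mapping used by CPUWorker;
    # the helper name is hypothetical and not part of this patch.
    from typing import List


    def cpu_bind_for_rank(bind_spec: str, rank: int, world_size: int) -> str:
        """Return the OpenMP core-id string for one tensor-parallel rank."""
        if bind_spec == "all":
            return "all"  # no explicit binding; init_cpu_threads_env is skipped
        groups: List[str] = bind_spec.split("|")
        assert len(groups) == world_size, (
            "expected one '|'-separated core group per tensor-parallel rank")
        return groups[rank]


    # Example: the 2-rank binding from the documentation.
    assert cpu_bind_for_rank("0-31|32-63", 0, 2) == "0-31"
    assert cpu_bind_for_rank("0-31|32-63", 1, 2) == "32-63"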