Commit 6b9b105

t3ga committed Apr 8, 2024
2 parents c0b8571 + 5727fb6
Showing 48 changed files with 2,203 additions and 365 deletions.
7 changes: 5 additions & 2 deletions .buildkite/test-pipeline.yaml
@@ -34,7 +34,10 @@ steps:
command: pytest -v -s engine tokenization test_sequence.py test_config.py

- label: Entrypoints Test
command: pytest -v -s entrypoints
commands:
# these tests have to be separated, because each one will allocate all possible GPU memory
- pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py
- pytest -v -s entrypoints/test_server_oot_registration.py

- label: Examples Test
working_dir: "/vllm-workspace/examples"
@@ -90,7 +93,7 @@ steps:
- bash run-benchmarks.sh

- label: Documentation Build
working_dir: "/vllm-workspace/docs"
working_dir: "/vllm-workspace/test_docs/docs"
no_gpu: True
commands:
- pip install -r requirements-docs.txt
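The entrypoints split above exists because each test spins up its own vLLM engine, and an engine claims most of the GPU up front. A minimal illustrative sketch, not part of the commit, using vLLM's default `gpu_memory_utilization` of 0.9 and a small model chosen only for the example:

```python
# Illustrative sketch only: a single LLM instance pre-allocates ~90% of the
# free GPU memory for its KV cache (gpu_memory_utilization defaults to 0.9),
# which is why the two entrypoints test files run in separate pytest processes.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.9)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8))
print(outputs[0].outputs[0].text)
```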
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
@@ -49,7 +49,7 @@ jobs:
matrix:
os: ['ubuntu-20.04']
python-version: ['3.8', '3.9', '3.10', '3.11']
pytorch-version: ['2.1.2'] # Must be the most recent version that meets requirements.txt.
pytorch-version: ['2.2.1'] # Must be the most recent version that meets requirements-cuda.txt.
cuda-version: ['11.8', '12.1']

steps:
2 changes: 1 addition & 1 deletion .github/workflows/scripts/build.sh
@@ -9,7 +9,7 @@ LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH

# Install requirements
$python_executable -m pip install wheel packaging
$python_executable -m pip install -r requirements.txt
$python_executable -m pip install -r requirements-cuda.txt

# Limit the number of parallel jobs to avoid OOM
export MAX_JOBS=1
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -31,7 +31,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
# requirements.txt files and should be kept consistent. The ROCm torch
# versions are derived from Dockerfile.rocm
#
set(TORCH_SUPPORTED_VERSION_CUDA "2.1.2")
set(TORCH_SUPPORTED_VERSION_CUDA "2.2.1")
set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")

1 change: 0 additions & 1 deletion CONTRIBUTING.md
@@ -21,7 +21,6 @@ Express your support on Twitter if vLLM aids you, or simply offer your appreciat
### Build from source

```bash
pip install -r requirements.txt
pip install -e . # This may take several minutes.
```

103 changes: 60 additions & 43 deletions Dockerfile
@@ -2,6 +2,7 @@
# to run the OpenAI compatible server.

#################### BASE BUILD IMAGE ####################
# prepare basic build environment
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev

RUN apt-get update -y \
@@ -16,18 +17,26 @@ RUN ldconfig /usr/local/cuda-12.1/compat/
WORKDIR /workspace

# install build and runtime dependencies
COPY requirements.txt requirements.txt
COPY requirements-common.txt requirements-common.txt
COPY requirements-cuda.txt requirements-cuda.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements.txt
pip install -r requirements-cuda.txt

# install development dependencies
COPY requirements-dev.txt requirements-dev.txt
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements-dev.txt

# cuda arch list used by torch
# can be useful for both `dev` and `test`
# explicitly set the list to avoid issues with torch 2.2
# see https://github.com/pytorch/pytorch/pull/123243
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
#################### BASE BUILD IMAGE ####################


#################### EXTENSION BUILD IMAGE ####################
#################### WHEEL BUILD IMAGE ####################
FROM dev AS build

# install build dependencies
@@ -38,18 +47,16 @@ RUN --mount=type=cache,target=/root/.cache/pip \
# install compiler cache to speed up compilation leveraging local or remote caching
RUN apt-get update -y && apt-get install -y ccache

# copy input files
# files and directories related to building wheels
COPY csrc csrc
COPY setup.py setup.py
COPY cmake cmake
COPY CMakeLists.txt CMakeLists.txt
COPY requirements.txt requirements.txt
COPY requirements-common.txt requirements-common.txt
COPY requirements-cuda.txt requirements-cuda.txt
COPY pyproject.toml pyproject.toml
COPY vllm/__init__.py vllm/__init__.py
COPY vllm vllm

# cuda arch list used by torch
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# max jobs used by Ninja to build extensions
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}
@@ -61,7 +68,15 @@ ENV VLLM_INSTALL_PUNICA_KERNELS=1

ENV CCACHE_DIR=/root/.cache/ccache
RUN --mount=type=cache,target=/root/.cache/ccache \
python3 setup.py build_ext --inplace
--mount=type=cache,target=/root/.cache/pip \
python3 setup.py bdist_wheel --dist-dir=dist

# the `vllm_nccl` package must be installed from source distribution
# pip is too smart to store a wheel in the cache, and other CI jobs
# will directly use the wheel from the cache, which is not what we want.
# we need to remove it manually
RUN --mount=type=cache,target=/root/.cache/pip \
pip cache remove vllm_nccl*
#################### WHEEL BUILD IMAGE ####################

#################### FLASH_ATTENTION Build IMAGE ####################
@@ -81,57 +96,59 @@ RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \

#################### FLASH_ATTENTION Build IMAGE ####################

#################### TEST IMAGE ####################
# image to run unit testing suite
FROM dev AS test

# copy pytorch extensions separately to avoid having to rebuild
# when python code changes
#################### vLLM installation IMAGE ####################
# image with vLLM installed
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base
WORKDIR /vllm-workspace
# ADD is used to preserve directory structure
ADD . /vllm-workspace/
COPY --from=build /workspace/vllm/*.so /vllm-workspace/vllm/
# Install flash attention (from pre-built wheel)

RUN apt-get update -y \
&& apt-get install -y python3-pip git vim

# Workaround for https://github.com/openai/triton/issues/2507 and
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image
# or future versions of triton.
RUN ldconfig /usr/local/cuda-12.1/compat/

# install vllm wheel first, so that torch etc will be installed
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/pip \
pip install dist/*.whl --verbose

RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
--mount=type=cache,target=/root/.cache/pip \
pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
# skip installing build dependencies because we are using pre-compiled extensions
RUN rm pyproject.toml
RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip install . --verbose
#################### TEST IMAGE ####################
#################### vLLM installation IMAGE ####################


#################### RUNTIME BASE IMAGE ####################
# We used base cuda image because pytorch installs its own cuda libraries.
# However pynccl depends on cuda libraries so we had to switch to the runtime image
# In the future it would be nice to get a container with pytorch and cuda without duplicating cuda
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base
#################### TEST IMAGE ####################
# image to run unit testing suite
# note that this uses vllm installed by `pip`
FROM vllm-base AS test

# libnccl required for ray
RUN apt-get update -y \
&& apt-get install -y python3-pip
ADD . /vllm-workspace/

WORKDIR /workspace
COPY requirements.txt requirements.txt
# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/pip \
pip install -r requirements.txt

# Install flash attention (from pre-built wheel)
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
pip install -r requirements-dev.txt

#################### RUNTIME BASE IMAGE ####################
# the docs require the source code
# we hide it inside `test_docs/`, so that this source code
# will not be imported by other tests
RUN mkdir test_docs
RUN mv docs test_docs/
RUN mv vllm test_docs/

#################### TEST IMAGE ####################

#################### OPENAI API SERVER ####################
# openai api server alternative
FROM vllm-base AS vllm-openai

# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
pip install accelerate hf_transfer modelscope

COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY vllm vllm

ENV VLLM_USAGE_SOURCE production-docker-image

ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
3 changes: 2 additions & 1 deletion MANIFEST.in
@@ -1,5 +1,6 @@
include LICENSE
include requirements.txt
include requirements-common.txt
include requirements-cuda.txt
include CMakeLists.txt

recursive-include cmake *
1 change: 1 addition & 0 deletions README.md
@@ -70,6 +70,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.)
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.)
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
14 changes: 12 additions & 2 deletions benchmarks/benchmark_latency.py
@@ -68,7 +68,8 @@ def run_to_completion(profile_dir: Optional[str] = None):
return latency

print("Warming up...")
run_to_completion(profile_dir=None)
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
run_to_completion(profile_dir=None)

if args.profile:
profile_dir = args.profile_result_dir
@@ -84,7 +85,12 @@ def run_to_completion(profile_dir: Optional[str] = None):
latencies = []
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
latencies.append(run_to_completion(profile_dir=None))
latencies = np.array(latencies)
percentages = [10, 25, 50, 75, 90]
percentiles = np.percentile(latencies, percentages)
print(f'Avg latency: {np.mean(latencies)} seconds')
for percentage, percentile in zip(percentages, percentiles):
print(f'{percentage}% percentile latency: {percentile} seconds')


if __name__ == '__main__':
@@ -106,9 +112,13 @@ def run_to_completion(profile_dir: Optional[str] = None):
default=1,
help='Number of generated sequences per prompt.')
parser.add_argument('--use-beam-search', action='store_true')
parser.add_argument('--num-iters-warmup',
type=int,
default=10,
help='Number of iterations to run for warmup.')
parser.add_argument('--num-iters',
type=int,
default=3,
default=30,
help='Number of iterations to run.')
parser.add_argument('--trust-remote-code',
action='store_true',
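The percentile reporting added to `benchmark_latency.py` can be exercised on its own; a standalone sketch with made-up latency values (the benchmark itself measures real iterations):

```python
# Standalone sketch of the new reporting logic, using synthetic latencies.
import numpy as np

latencies = np.array([0.91, 0.95, 0.98, 1.02, 1.05, 1.10, 1.30])  # seconds, made up
percentages = [10, 25, 50, 75, 90]
percentiles = np.percentile(latencies, percentages)
print(f'Avg latency: {np.mean(latencies)} seconds')
for percentage, percentile in zip(percentages, percentiles):
    print(f'{percentage}% percentile latency: {percentile} seconds')
```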
3 changes: 0 additions & 3 deletions docs/source/conf.py
@@ -11,13 +11,10 @@
# documentation root, use os.path.abspath to make it absolute, like shown here.

import logging
import os
import sys

from sphinx.ext import autodoc

sys.path.insert(0, os.path.abspath(os.path.join('..', '..')))

logger = logging.getLogger(__name__)

# -- Project information -----------------------------------------------------
27 changes: 27 additions & 0 deletions docs/source/models/adding_model.rst
@@ -21,6 +21,8 @@ This document provides a high-level guide on integrating a `HuggingFace Transfor
Start by forking our `GitHub`_ repository and then :ref:`build it from source <build_from_source>`.
This gives you the ability to modify the codebase and test your model.

.. tip::
If you don't want to fork the repository and modify vLLM's codebase, please refer to the "Out-of-Tree Model Integration" section below.

1. Bring your model code
------------------------
@@ -94,3 +96,28 @@ This method should load the weights from the HuggingFace's checkpoint file and a
----------------------

Finally, include your :code:`*ForCausalLM` class in `vllm/model_executor/models/__init__.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/__init__.py>`_ and register it to the :code:`_MODEL_REGISTRY` in `vllm/model_executor/model_loader.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/model_loader.py>`_.

6. Out-of-Tree Model Integration
--------------------------------------------

We also provide a way to integrate a model without modifying the vLLM codebase. Steps 2, 3, and 4 are still required, but you can skip steps 1 and 5.

Just add the following lines to your code:

.. code-block:: python

    from vllm import ModelRegistry
    from your_code import YourModelForCausalLM
    ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)

If you are running the API server with `python -m vllm.entrypoints.openai.api_server args`, you can wrap the entrypoint with the following code:

.. code-block:: python

    from vllm import ModelRegistry
    from your_code import YourModelForCausalLM
    ModelRegistry.register_model("YourModelForCausalLM", YourModelForCausalLM)
    import runpy
    runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')

Save the above code in a file and run it with `python your_file.py args`.
16 changes: 16 additions & 0 deletions docs/source/models/engine_args.rst
@@ -118,3 +118,19 @@ Below, you can find an explanation of every engine argument for vLLM:
.. option:: --quantization (-q) {awq,squeezellm,None}

Method used to quantize the weights.

Async Engine Arguments
----------------------
Below are the additional arguments related to the asynchronous engine:

.. option:: --engine-use-ray

Use Ray to start the LLM engine in a process separate from the server process.

.. option:: --disable-log-requests

Disable logging of requests.

.. option:: --max-log-len

Maximum number of prompt characters or prompt token IDs printed in the log. Defaults to unlimited.
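For context, a rough sketch of how these flags map onto the Python API. This assumes `AsyncEngineArgs` exposes them as attributes with the same names, which may vary between vLLM versions:

```python
# Hedged sketch (not from the commit): programmatic counterparts of the
# async-engine CLI flags documented above. The attribute names are assumed
# to mirror the flags.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

engine_args = AsyncEngineArgs(
    model="facebook/opt-125m",
    engine_use_ray=False,        # --engine-use-ray
    disable_log_requests=True,   # --disable-log-requests
    max_log_len=100,             # --max-log-len
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
```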
4 changes: 4 additions & 0 deletions docs/source/models/supported_models.rst
@@ -83,6 +83,10 @@ Alongside each architecture, we include some popular models that use it.
- LLaMA, LLaMA-2, Vicuna, Alpaca, Yi
- :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
- ✅︎
* - :code:`MiniCPMForCausalLM`
- MiniCPM
- :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
-
* - :code:`MistralForCausalLM`
- Mistral, Mistral-Instruct
- :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
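A quick way to confirm the new architecture is present in a given vLLM build; a sketch assuming `ModelRegistry.get_supported_archs()` exists in your version (if it does not, the architectures listed in `vllm/model_executor/models/__init__.py` carry the same information):

```python
# Sketch: check that the newly added MiniCPMForCausalLM architecture is
# registered. get_supported_archs() is assumed to exist on ModelRegistry.
from vllm import ModelRegistry

archs = ModelRegistry.get_supported_archs()
print("MiniCPMForCausalLM" in archs)
```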