
Commit

Merge branch 'main' into remove-deprecations
DarkLight1337 committed Jan 24, 2025
2 parents a737a96 + ab5bbf5 commit 8c88cbf
Showing 11 changed files with 73 additions and 51 deletions.
7 changes: 5 additions & 2 deletions .buildkite/check-wheel-size.py
@@ -2,8 +2,11 @@
import sys
import zipfile

-# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 250 MB
-VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 250))
+# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 300 MiB
+# Note that we have 400 MiB quota, please use it wisely.
+# See https://github.com/pypi/support/issues/3792 .
+# Please also sync the value with the one in Dockerfile.
+VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 300))


def print_top_10_largest_files(zip_file):
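The default in this script must now match the `VLLM_MAX_SIZE_MB` ARG in the Dockerfile (see below). For context, a wheel-size check of this kind usually just compares the built wheel's on-disk size against the limit; the following is a minimal, hypothetical sketch (function name, CLI handling, and messages are illustrative, not copied from the real script):

```python
import os
import sys

VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 300))


def check_wheel_size(wheel_path: str) -> int:
    # Compare the wheel's on-disk size (in MB) against the configured limit.
    wheel_size_mb = os.path.getsize(wheel_path) / (1024 * 1024)
    if wheel_size_mb > VLLM_MAX_SIZE_MB:
        print(f"{wheel_path} is {wheel_size_mb:.2f} MB, "
              f"exceeding the {VLLM_MAX_SIZE_MB} MB limit.")
        return 1
    print(f"{wheel_path} is {wheel_size_mb:.2f} MB, within the limit.")
    return 0


if __name__ == "__main__":
    sys.exit(check_wheel_size(sys.argv[1]))
```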
2 changes: 1 addition & 1 deletion CMakeLists.txt
100644 → 100755
@@ -576,7 +576,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-GIT_TAG 90eacc1af2a7c3de62ea249e929ed5faccf38954
+GIT_TAG 0aff05f577e8a10086066a00618609199b25231d
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
4 changes: 2 additions & 2 deletions Dockerfile
@@ -126,8 +126,8 @@ RUN --mount=type=cache,target=/root/.cache/ccache \

# Check the size of the wheel if RUN_WHEEL_CHECK is true
COPY .buildkite/check-wheel-size.py check-wheel-size.py
-# Default max size of the wheel is 250MB
-ARG VLLM_MAX_SIZE_MB=250
+# sync the default value with .buildkite/check-wheel-size.py
+ARG VLLM_MAX_SIZE_MB=300
ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB
ARG RUN_WHEEL_CHECK=true
RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
15 changes: 10 additions & 5 deletions benchmarks/backend_request_func.py
@@ -51,7 +51,8 @@ async def async_request_tgi(
api_url = request_func_input.api_url
assert api_url.endswith("generate_stream")

-async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+async with aiohttp.ClientSession(trust_env=True,
+timeout=AIOHTTP_TIMEOUT) as session:
params = {
"best_of": request_func_input.best_of,
"max_new_tokens": request_func_input.output_len,
@@ -123,7 +124,8 @@ async def async_request_trt_llm(
api_url = request_func_input.api_url
assert api_url.endswith("generate_stream")

-async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+async with aiohttp.ClientSession(trust_env=True,
+timeout=AIOHTTP_TIMEOUT) as session:
assert request_func_input.best_of == 1
payload = {
"accumulate_tokens": True,
@@ -187,7 +189,8 @@ async def async_request_deepspeed_mii(
request_func_input: RequestFuncInput,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
-async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+async with aiohttp.ClientSession(trust_env=True,
+timeout=AIOHTTP_TIMEOUT) as session:
assert request_func_input.best_of == 1

payload = {
@@ -235,7 +238,8 @@ async def async_request_openai_completions(
("completions", "profile")
), "OpenAI Completions API URL must end with 'completions' or 'profile'."

-async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+async with aiohttp.ClientSession(trust_env=True,
+timeout=AIOHTTP_TIMEOUT) as session:
payload = {
"model": request_func_input.model_name \
if request_func_input.model_name else request_func_input.model,
@@ -333,7 +337,8 @@ async def async_request_openai_chat_completions(
"chat/completions"
), "OpenAI Chat Completions API URL must end with 'chat/completions'."

-async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+async with aiohttp.ClientSession(trust_env=True,
+timeout=AIOHTTP_TIMEOUT) as session:
content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content)
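Each benchmark backend above now opens its `aiohttp.ClientSession` with `trust_env=True`, so the client honors `HTTP_PROXY`/`HTTPS_PROXY`/`NO_PROXY` and `~/.netrc` from the environment instead of ignoring them. A self-contained sketch of the same pattern (the URL and timeout value are placeholders):

```python
import asyncio

import aiohttp

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


async def fetch(url: str) -> str:
    # trust_env=True makes the session pick up proxy settings and .netrc
    # credentials from the environment.
    async with aiohttp.ClientSession(trust_env=True,
                                     timeout=AIOHTTP_TIMEOUT) as session:
        async with session.get(url) as resp:
            return await resp.text()


if __name__ == "__main__":
    print(asyncio.run(fetch("http://localhost:8000/health")))
```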
32 changes: 12 additions & 20 deletions benchmarks/benchmark_serving.py
@@ -200,7 +200,7 @@ def sample_sonnet_requests(
return sampled_requests


-def sample_mmmu_pro_vision_requests(
+def sample_vision_arena_requests(
dataset,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
@@ -212,13 +212,7 @@ def sample_mmmu_pro_vision_requests(
if len(sampled_requests) == num_requests:
break

-# MMMU-Pro vision direct prompt
-# Ref: https://github.com/MMMU-Benchmark/MMMU/blob/6ce42f4d8f70c1841c67867152648974415b5cac/mmmu-pro/prompts.yaml#L5
-prompt = (
-"Answer with the option letter from the given choices directly. "
-"The last line of your response should be of the following "
-"format: 'Answer: $LETTER' (without quotes) where LETTER is one of "
-"options.")
+prompt = data["turns"][0][0]['content']

prompt_token_ids = tokenizer(prompt).input_ids
if fixed_output_len is None:
@@ -230,10 +224,10 @@ def sample_mmmu_pro_vision_requests(
output_len = fixed_output_len

assert isinstance(
data["image"],
data["images"][0],
Image), ("Input image format must be `PIL.Image.Image`, "
f"given {type(data['image'])}.")
-image: Image = data["image"]
+image: Image = data["images"][0]
image = image.convert("RGB")
image_data = io.BytesIO()
image.save(image_data, format='JPEG')
@@ -252,27 +246,25 @@

def sample_hf_requests(
dataset_path: str,
-dataset_subset: str,
+dataset_subset: Optional[str],
dataset_split: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
random_seed: int,
fixed_output_len: Optional[int] = None,
) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:

-# Special case for MMMU-Pro vision dataset
-if dataset_path == 'MMMU/MMMU_Pro' and dataset_subset == 'vision':
-assert dataset_split == "test"
+# Special case for vision_arena dataset
+if dataset_path == 'lmarena-ai/vision-arena-bench-v0.1' \
+and dataset_subset is None:
+assert dataset_split == "train"
dataset = load_dataset(dataset_path,
name=dataset_subset,
split=dataset_split,
streaming=True)
assert "image" in dataset.features, (
"MMMU/MMMU_Pro vision dataset must have 'image' column.")
filter_func = lambda x: isinstance(x["image"], Image)
dataset = dataset.shuffle(seed=random_seed).filter(filter_func)
return sample_mmmu_pro_vision_requests(dataset, num_requests,
tokenizer, fixed_output_len)
dataset = dataset.shuffle(seed=random_seed)
return sample_vision_arena_requests(dataset, num_requests, tokenizer,
fixed_output_len)

dataset = load_dataset(dataset_path,
name=dataset_subset,
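The rewritten sampler pulls the prompt from the first conversation turn and the image from the `images` column, then re-encodes the image as RGB JPEG. A condensed sketch of that flow; the final base64 data-URL step goes beyond what the hunk shows and is only an assumption about how the image would be attached to an OpenAI-style multimodal request:

```python
import base64
import io

from datasets import load_dataset
from PIL import Image

dataset = load_dataset("lmarena-ai/vision-arena-bench-v0.1",
                       split="train",
                       streaming=True).shuffle(seed=0)

for data in dataset:
    prompt = data["turns"][0][0]["content"]
    image: Image.Image = data["images"][0].convert("RGB")

    # Serialize to JPEG and wrap as a data URL for a multimodal payload.
    buf = io.BytesIO()
    image.save(buf, format="JPEG")
    image_url = ("data:image/jpeg;base64," +
                 base64.b64encode(buf.getvalue()).decode("utf-8"))
    mm_content = {"type": "image_url", "image_url": {"url": image_url}}
    print(prompt[:60], len(mm_content["image_url"]["url"]))
    break
```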
@@ -59,6 +59,7 @@ To build and install vLLM from source, run:
```console
git clone https://github.com/vllm-project/vllm.git
cd vllm
+pip install -r requirements-hpu.txt
python setup.py develop
```

@@ -68,6 +69,7 @@ Currently, the latest features and performance optimizations are developed in Ga
git clone https://github.com/HabanaAI/vllm-fork.git
cd vllm-fork
git checkout habana_main
+pip install -r requirements-hpu.txt
python setup.py develop
```

5 changes: 4 additions & 1 deletion setup.py
File mode changed: 100644 → 100755
@@ -598,7 +598,10 @@ def _read_requirements(filename: str) -> List[str]:

if _is_cuda():
ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
-ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
+if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.0"):
+# FA3 requires CUDA 12.0 or later
+ext_modules.append(
+CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))

if _build_custom_ops():
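With this change, the FA3 extension is compiled only when precompiled binaries are used or the CUDA toolkit is at least 12.0. The sketch below illustrates the version-gate pattern with `packaging.version`; the nvcc-parsing helper is a simplified stand-in, not the one defined in setup.py:

```python
import subprocess

from packaging.version import Version


def get_nvcc_cuda_version() -> Version:
    # Simplified stand-in: parse "release X.Y" from `nvcc --version` output.
    output = subprocess.check_output(["nvcc", "--version"], text=True)
    release_line = next(line for line in output.splitlines()
                        if "release" in line)
    release = release_line.split("release ")[1].split(",")[0]
    return Version(release)


if __name__ == "__main__":
    if get_nvcc_cuda_version() >= Version("12.0"):
        print("CUDA >= 12.0: the FA3 extension would be built.")
    else:
        print("CUDA < 12.0: the FA3 extension is skipped.")
```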
11 changes: 6 additions & 5 deletions tests/kernels/test_cascade_flash_attn.py
File mode changed: 100644 → 100755
@@ -6,7 +6,9 @@
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import (cascade_attention,
merge_attn_states)
-from vllm.vllm_flash_attn import flash_attn_varlen_func
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+flash_attn_varlen_func,
+is_fa_version_supported)

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 192, 256]
@@ -91,10 +93,9 @@ def test_cascade(
fa_version: int,
) -> None:
torch.set_default_device("cuda")
-if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
-or torch.cuda.get_device_capability() == (8, 9)):
-pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
-"insufficient shared memory for some shapes")
+if not is_fa_version_supported(fa_version):
+pytest.skip(f"Flash attention version {fa_version} not supported due "
+f"to: \"{fa_version_unsupported_reason(fa_version)}\"")

current_platform.seed_everything(0)

21 changes: 10 additions & 11 deletions tests/kernels/test_flash_attn.py
@@ -4,8 +4,10 @@
import torch

from vllm.platforms import current_platform
-from vllm.vllm_flash_attn import (flash_attn_varlen_func,
-flash_attn_with_kvcache)
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+flash_attn_varlen_func,
+flash_attn_with_kvcache,
+is_fa_version_supported)

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
@@ -95,10 +97,9 @@ def test_flash_attn_with_paged_kv(
fa_version: int,
) -> None:
torch.set_default_device("cuda")
-if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
-or torch.cuda.get_device_capability() == (8, 9)):
-pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
-"insufficient shared memory for some shapes")
+if not is_fa_version_supported(fa_version):
+pytest.skip(f"Flash attention version {fa_version} not supported due "
+f"to: \"{fa_version_unsupported_reason(fa_version)}\"")

current_platform.seed_everything(0)
num_seqs = len(kv_lens)
@@ -182,11 +183,9 @@ def test_varlen_with_paged_kv(
fa_version: int,
) -> None:
torch.set_default_device("cuda")
-if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
-or torch.cuda.get_device_capability() == (8, 9)):
-pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
-"insufficient shared memory for some shapes")
-
+if not is_fa_version_supported(fa_version):
+pytest.skip(f"Flash attention version {fa_version} not supported due "
+f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
current_platform.seed_everything(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
14 changes: 11 additions & 3 deletions vllm/attention/backends/flash_attn.py
File mode changed: 100644 → 100755
@@ -18,17 +18,20 @@
get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
is_all_encoder_attn_metadata_set, is_block_tables_empty)
from vllm.envs import VLLM_FLASH_ATTN_VERSION
+from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.platforms import current_platform
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+flash_attn_varlen_func,
+flash_attn_with_kvcache,
+is_fa_version_supported)

if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)

-from vllm.vllm_flash_attn import (flash_attn_varlen_func,
-flash_attn_with_kvcache,
-is_fa_version_supported)
+logger = init_logger(__name__)


class FlashAttentionBackend(AttentionBackend):
@@ -652,6 +655,11 @@ def __init__(
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
self.fa_version = VLLM_FLASH_ATTN_VERSION

+if not is_fa_version_supported(self.fa_version):
+logger.error("Cannot use FA version %d is not supported due to %s",
+self.fa_version,
+fa_version_unsupported_reason(self.fa_version))
+
assert is_fa_version_supported(self.fa_version)

def forward(
11 changes: 10 additions & 1 deletion vllm/v1/attention/backends/flash_attn.py
File mode changed: 100644 → 100755
@@ -10,11 +10,15 @@
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.envs import VLLM_FLASH_ATTN_VERSION
+from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
-from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+flash_attn_varlen_func,
is_fa_version_supported)

+logger = init_logger(__name__)
+

class FlashAttentionBackend(AttentionBackend):

@@ -143,6 +147,11 @@ def __init__(
assert VLLM_FLASH_ATTN_VERSION in [2, 3]
self.fa_version = VLLM_FLASH_ATTN_VERSION

+if not is_fa_version_supported(self.fa_version):
+logger.error("Cannot use FA version %d is not supported due to %s",
+self.fa_version,
+fa_version_unsupported_reason(self.fa_version))
+
assert is_fa_version_supported(self.fa_version)

def forward(
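Both attention backends now log the reason before asserting that the selected flash-attention version is usable. For reference, the two helpers can also be exercised on their own; a minimal sketch assuming a vLLM build that bundles `vllm.vllm_flash_attn`:

```python
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
                                  is_fa_version_supported)

for fa_version in (2, 3):
    if is_fa_version_supported(fa_version):
        print(f"FA{fa_version} is supported by this build/GPU.")
    else:
        # The reason is a human-readable string, e.g. a missing compiled
        # extension or an unsupported compute capability.
        print(f"FA{fa_version} unsupported: "
              f"{fa_version_unsupported_reason(fa_version)}")
```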
