diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py
index 0412c5f37952d..e29eb78a9f945 100644
--- a/.buildkite/check-wheel-size.py
+++ b/.buildkite/check-wheel-size.py
@@ -2,8 +2,11 @@
 import sys
 import zipfile
 
-# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 250 MB
-VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 250))
+# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 300 MiB
+# Note that we have 400 MiB quota, please use it wisely.
+# See https://github.com/pypi/support/issues/3792 .
+# Please also sync the value with the one in Dockerfile.
+VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 300))
 
 
 def print_top_10_largest_files(zip_file):
diff --git a/CMakeLists.txt b/CMakeLists.txt
old mode 100644
new mode 100755
index 5039ac2448f83..2f9da6fa3e1d3
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -576,7 +576,7 @@ else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG 90eacc1af2a7c3de62ea249e929ed5faccf38954
+        GIT_TAG 0aff05f577e8a10086066a00618609199b25231d
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
diff --git a/Dockerfile b/Dockerfile
index 261f5440aee47..cb9cf0da5be65 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -126,8 +126,8 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
 
 # Check the size of the wheel if RUN_WHEEL_CHECK is true
 COPY .buildkite/check-wheel-size.py check-wheel-size.py
-# Default max size of the wheel is 250MB
-ARG VLLM_MAX_SIZE_MB=250
+# sync the default value with .buildkite/check-wheel-size.py
+ARG VLLM_MAX_SIZE_MB=300
 ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB
 ARG RUN_WHEEL_CHECK=true
 RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py
index d098c110cd921..0612e8778aca5 100644
--- a/benchmarks/backend_request_func.py
+++ b/benchmarks/backend_request_func.py
@@ -51,7 +51,8 @@ async def async_request_tgi(
     api_url = request_func_input.api_url
     assert api_url.endswith("generate_stream")
 
-    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True,
+                                     timeout=AIOHTTP_TIMEOUT) as session:
         params = {
             "best_of": request_func_input.best_of,
             "max_new_tokens": request_func_input.output_len,
@@ -123,7 +124,8 @@ async def async_request_trt_llm(
     api_url = request_func_input.api_url
     assert api_url.endswith("generate_stream")
 
-    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True,
+                                     timeout=AIOHTTP_TIMEOUT) as session:
         assert request_func_input.best_of == 1
         payload = {
             "accumulate_tokens": True,
@@ -187,7 +189,8 @@ async def async_request_deepspeed_mii(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,
 ) -> RequestFuncOutput:
-    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True,
+                                     timeout=AIOHTTP_TIMEOUT) as session:
         assert request_func_input.best_of == 1
 
         payload = {
@@ -235,7 +238,8 @@ async def async_request_openai_completions(
         ("completions", "profile")
     ), "OpenAI Completions API URL must end with 'completions' or 'profile'."
 
-    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True,
+                                     timeout=AIOHTTP_TIMEOUT) as session:
         payload = {
             "model": request_func_input.model_name \
                 if request_func_input.model_name else request_func_input.model,
@@ -333,7 +337,8 @@ async def async_request_openai_chat_completions(
         "chat/completions"
     ), "OpenAI Chat Completions API URL must end with 'chat/completions'."
 
-    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+    async with aiohttp.ClientSession(trust_env=True,
+                                     timeout=AIOHTTP_TIMEOUT) as session:
         content = [{"type": "text", "text": request_func_input.prompt}]
         if request_func_input.multi_modal_content:
             content.append(request_func_input.multi_modal_content)
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index bc026b0ec1ca6..63d2c3f7c7dd9 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -200,7 +200,7 @@ def sample_sonnet_requests(
     return sampled_requests
 
 
-def sample_mmmu_pro_vision_requests(
+def sample_vision_arena_requests(
     dataset,
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
@@ -212,13 +212,7 @@ def sample_mmmu_pro_vision_requests(
         if len(sampled_requests) == num_requests:
             break
 
-        # MMMU-Pro vision direct prompt
-        # Ref: https://github.com/MMMU-Benchmark/MMMU/blob/6ce42f4d8f70c1841c67867152648974415b5cac/mmmu-pro/prompts.yaml#L5
-        prompt = (
-            "Answer with the option letter from the given choices directly. "
-            "The last line of your response should be of the following "
-            "format: 'Answer: $LETTER' (without quotes) where LETTER is one of "
-            "options.")
+        prompt = data["turns"][0][0]['content']
 
         prompt_token_ids = tokenizer(prompt).input_ids
         if fixed_output_len is None:
@@ -230,10 +224,10 @@ def sample_mmmu_pro_vision_requests(
             output_len = fixed_output_len
 
         assert isinstance(
-            data["image"],
+            data["images"][0],
             Image), ("Input image format must be `PIL.Image.Image`, "
                      f"given {type(data['image'])}.")
-        image: Image = data["image"]
+        image: Image = data["images"][0]
         image = image.convert("RGB")
         image_data = io.BytesIO()
         image.save(image_data, format='JPEG')
@@ -252,7 +246,7 @@ def sample_mmmu_pro_vision_requests(
 
 def sample_hf_requests(
     dataset_path: str,
-    dataset_subset: str,
+    dataset_subset: Optional[str],
     dataset_split: str,
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
@@ -260,19 +254,17 @@
     fixed_output_len: Optional[int] = None,
 ) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]:
 
-    # Special case for MMMU-Pro vision dataset
-    if dataset_path == 'MMMU/MMMU_Pro' and dataset_subset == 'vision':
-        assert dataset_split == "test"
+    # Special case for vision_arena dataset
+    if dataset_path == 'lmarena-ai/vision-arena-bench-v0.1' \
+            and dataset_subset is None:
+        assert dataset_split == "train"
         dataset = load_dataset(dataset_path,
                                name=dataset_subset,
                                split=dataset_split,
                                streaming=True)
-        assert "image" in dataset.features, (
-            "MMMU/MMMU_Pro vision dataset must have 'image' column.")
-        filter_func = lambda x: isinstance(x["image"], Image)
-        dataset = dataset.shuffle(seed=random_seed).filter(filter_func)
-        return sample_mmmu_pro_vision_requests(dataset, num_requests,
-                                               tokenizer, fixed_output_len)
+        dataset = dataset.shuffle(seed=random_seed)
+        return sample_vision_arena_requests(dataset, num_requests, tokenizer,
+                                            fixed_output_len)
 
     dataset = load_dataset(dataset_path,
                            name=dataset_subset,
diff --git a/docs/source/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md b/docs/source/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md
index b4695d504b601..ae42dd0c0d08f 100644
--- a/docs/source/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md
+++ b/docs/source/getting_started/installation/ai_accelerator/hpu-gaudi.inc.md
@@ -59,6 +59,7 @@ To build and install vLLM from source, run:
 ```console
 git clone https://github.com/vllm-project/vllm.git
 cd vllm
+pip install -r requirements-hpu.txt
 python setup.py develop
 ```
 
@@ -68,6 +69,7 @@ Currently, the latest features and performance optimizations are developed in Ga
 git clone https://github.com/HabanaAI/vllm-fork.git
 cd vllm-fork
 git checkout habana_main
+pip install -r requirements-hpu.txt
 python setup.py develop
 ```
diff --git a/setup.py b/setup.py
old mode 100644
new mode 100755
index 36c89d435c7b7..ee193e4693806
--- a/setup.py
+++ b/setup.py
@@ -598,7 +598,10 @@ def _read_requirements(filename: str) -> List[str]:
 
 if _is_cuda():
     ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C"))
-    ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
+    if envs.VLLM_USE_PRECOMPILED or get_nvcc_cuda_version() >= Version("12.0"):
+        # FA3 requires CUDA 12.0 or later
+        ext_modules.append(
+            CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa3_C"))
     ext_modules.append(CMakeExtension(name="vllm.cumem_allocator"))
 
 if _build_custom_ops():
diff --git a/tests/kernels/test_cascade_flash_attn.py b/tests/kernels/test_cascade_flash_attn.py
old mode 100644
new mode 100755
index 00eb927205d46..8edfde42ede74
--- a/tests/kernels/test_cascade_flash_attn.py
+++ b/tests/kernels/test_cascade_flash_attn.py
@@ -6,7 +6,9 @@
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (cascade_attention,
                                                    merge_attn_states)
-from vllm.vllm_flash_attn import flash_attn_varlen_func
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+                                  flash_attn_varlen_func,
+                                  is_fa_version_supported)
 
 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 192, 256]
@@ -91,10 +93,9 @@ def test_cascade(
     fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
-    if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
-                            or torch.cuda.get_device_capability() == (8, 9)):
-        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
-                    "insufficient shared memory for some shapes")
+    if not is_fa_version_supported(fa_version):
+        pytest.skip(f"Flash attention version {fa_version} not supported due "
+                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
 
     current_platform.seed_everything(0)
 
diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py
index b22153c86b25f..0ee0bf6c6a374 100644
--- a/tests/kernels/test_flash_attn.py
+++ b/tests/kernels/test_flash_attn.py
@@ -4,8 +4,10 @@
 import torch
 
 from vllm.platforms import current_platform
-from vllm.vllm_flash_attn import (flash_attn_varlen_func,
-                                  flash_attn_with_kvcache)
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+                                  flash_attn_varlen_func,
+                                  flash_attn_with_kvcache,
+                                  is_fa_version_supported)
 
 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 256]
@@ -95,10 +97,9 @@ def test_flash_attn_with_paged_kv(
     fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
-    if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
-                            or torch.cuda.get_device_capability() == (8, 9)):
-        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
-                    "insufficient shared memory for some shapes")
+    if not is_fa_version_supported(fa_version):
+        pytest.skip(f"Flash attention version {fa_version} not supported due "
+                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
 
     current_platform.seed_everything(0)
     num_seqs = len(kv_lens)
@@ -182,11 +183,9 @@ def test_varlen_with_paged_kv(
     fa_version: int,
 ) -> None:
     torch.set_default_device("cuda")
-    if fa_version == 3 and (torch.cuda.get_device_capability() == (8, 6)
-                            or torch.cuda.get_device_capability() == (8, 9)):
-        pytest.skip("Flash attention version 3 fails on 8.6 and 8.9 due to "
-                    "insufficient shared memory for some shapes")
-
+    if not is_fa_version_supported(fa_version):
+        pytest.skip(f"Flash attention version {fa_version} not supported due "
+                    f"to: \"{fa_version_unsupported_reason(fa_version)}\"")
     current_platform.seed_everything(0)
     num_seqs = len(seq_lens)
     query_lens = [x[0] for x in seq_lens]
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
old mode 100644
new mode 100755
index 1be099283e472..4a9aa1e217365
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -18,17 +18,20 @@
     get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
     is_all_encoder_attn_metadata_set, is_block_tables_empty)
 from vllm.envs import VLLM_FLASH_ATTN_VERSION
+from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+                                  flash_attn_varlen_func,
+                                  flash_attn_with_kvcache,
+                                  is_fa_version_supported)
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                           ModelInputForGPUWithSamplingMetadata)
 
-from vllm.vllm_flash_attn import (flash_attn_varlen_func,
-                                  flash_attn_with_kvcache,
-                                  is_fa_version_supported)
+logger = init_logger(__name__)
 
 
 class FlashAttentionBackend(AttentionBackend):
@@ -652,6 +655,11 @@ def __init__(
 
         assert VLLM_FLASH_ATTN_VERSION in [2, 3]
         self.fa_version = VLLM_FLASH_ATTN_VERSION
+        if not is_fa_version_supported(self.fa_version):
+            logger.error("Cannot use FA version %d is not supported due to %s",
+                         self.fa_version,
+                         fa_version_unsupported_reason(self.fa_version))
+        assert is_fa_version_supported(self.fa_version)
 
     def forward(
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
old mode 100644
new mode 100755
index 7fe9b3a8f595a..ce83b1fac6c0b
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -10,11 +10,15 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType)
 from vllm.envs import VLLM_FLASH_ATTN_VERSION
+from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
-from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
+                                  flash_attn_varlen_func,
                                   is_fa_version_supported)
 
+logger = init_logger(__name__)
+
 
 class FlashAttentionBackend(AttentionBackend):
@@ -143,6 +147,11 @@ def __init__(
 
         assert VLLM_FLASH_ATTN_VERSION in [2, 3]
         self.fa_version = VLLM_FLASH_ATTN_VERSION
+        if not is_fa_version_supported(self.fa_version):
+            logger.error("Cannot use FA version %d is not supported due to %s",
+                         self.fa_version,
+                         fa_version_unsupported_reason(self.fa_version))
+        assert is_fa_version_supported(self.fa_version)
 
     def forward(
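For reference, a minimal usage sketch (not part of the patch) of the two helpers this diff threads through the tests and attention backends, following the same check-then-assert pattern added to the backends; the chosen version number and the print call are illustrative only:

```python
# Sketch, assuming the helpers behave as used in the diff above:
# is_fa_version_supported(version) -> bool,
# fa_version_unsupported_reason(version) -> human-readable string.
from vllm.vllm_flash_attn import (fa_version_unsupported_reason,
                                  is_fa_version_supported)

fa_version = 3  # FA3; per the setup.py change it is only built with CUDA >= 12.0
if not is_fa_version_supported(fa_version):
    # Surface the reason before failing, e.g. an unsupported compute
    # capability or a wheel built without the FA3 extension.
    print(f"FA version {fa_version} unsupported: "
          f"{fa_version_unsupported_reason(fa_version)}")
assert is_fa_version_supported(fa_version)
```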