From c7ddd182ad15623c3d90fe7c88e6d8a1f4011eba Mon Sep 17 00:00:00 2001 From: lilyliu Date: Mon, 1 Jul 2024 16:23:05 -0700 Subject: [PATCH 01/13] logits_soft_cap for gemma2 in flashinfer --- tests/models/test_models.py | 6 ++++++ vllm/attention/backends/flashinfer.py | 12 ++++++++---- vllm/worker/model_runner.py | 14 ++++++++++++-- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 4cd2cb665c8f0..9dff9d55df46e 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -5,6 +5,8 @@ Run `pytest tests/models/test_models.py`. """ +import os + import pytest from .utils import check_outputs_equal @@ -20,6 +22,7 @@ # "allenai/OLMo-1B", # Broken "bigcode/starcoder2-3b", "google/gemma-1.1-2b-it", + "google/gemma-2-9b" ] @@ -40,6 +43,9 @@ def test_models( with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) + if "gemma-2" in model: + os.environ['VLLM_ATTENTION_BACKEND'] = "FLASHINFER" + with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 4ecac7379c7f6..a332f4e052028 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -102,6 +102,8 @@ class FlashInferMetadata(AttentionMetadata): # The data type of the paged kv cache data_type: torch.dtype = None device: torch.device = torch.device("cuda") + # Only used by gemma2 model + logits_soft_cap: Optional[float] = None def __post_init__(self): # Refer to @@ -269,9 +271,11 @@ def forward( else: assert prefill_meta is not None assert prefill_meta.prefill_wrapper is not None - output = prefill_meta.prefill_wrapper.forward(query, - kv_cache, - causal=True) + output = prefill_meta.prefill_wrapper.forward( + query, + kv_cache, + logits_soft_cap=attn_metadata.logits_soft_cap, + causal=True) else: assert attn_metadata.decode_metadata is not None assert attn_metadata.decode_metadata.decode_wrapper is not None @@ -279,5 +283,5 @@ def forward( query, kv_cache, sm_scale=self.scale, - ) + logits_soft_cap=attn_metadata.logits_soft_cap) return output.view(num_tokens, hidden_size) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 942063677a427..51cb8276188e8 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -659,6 +659,16 @@ def _prepare_model_input_tensors( dtype=torch.long, device=self.device) + logits_soft_cap = getattr(self.model_config.hf_config, + 'final_logit_softcapping', None) + if logits_soft_cap is not None and self.attn_backend.get_name( + ) != "flashinfer": + logger.warning(("Please use Flashinfer backend for models with", + "logits_soft_cap (i.e., Gemma-2).", + " Otherwise, the output might be wrong.", + " Set Flashinfer backend by ", + "export VLLM_ATTENTION_BACKEND=FLASHINFER.")) + if self.attn_backend.get_name() == "flashinfer": if len(paged_kv_indptr) > 0: paged_kv_indices_tensor = torch.tensor(paged_kv_indices, @@ -676,7 +686,6 @@ def _prepare_model_input_tensors( kv_cache_dtype = get_kv_cache_torch_dtype(self.kv_cache_dtype, self.model_config.dtype) - attn_metadata = self.attn_backend.make_metadata( num_prefills=num_prefills, slot_mapping=slot_mapping_tensor, @@ -697,7 +706,8 @@ def _prepare_model_input_tensors( query_start_loc=query_start_loc, device=self.device, data_type=kv_cache_dtype, - use_cuda_graph=use_captured_graph) + use_cuda_graph=use_captured_graph, + logits_soft_cap=logits_soft_cap) else: attn_metadata = self.attn_backend.make_metadata( From ceb7a16e4dd7c87d291736156986800843955c35 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Mon, 1 Jul 2024 16:35:14 -0700 Subject: [PATCH 02/13] separate tests --- tests/models/test_gemma2.py | 42 +++++++++++++++++++++++++++++++++++++ tests/models/test_models.py | 8 +------ vllm/attention/selector.py | 2 +- 3 files changed, 44 insertions(+), 8 deletions(-) create mode 100644 tests/models/test_gemma2.py diff --git a/tests/models/test_gemma2.py b/tests/models/test_gemma2.py new file mode 100644 index 0000000000000..e11527a056fee --- /dev/null +++ b/tests/models/test_gemma2.py @@ -0,0 +1,42 @@ +"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling. + +Run `pytest tests/models/test_mistral.py`. +""" +import pytest +import os + +from .utils import check_logprobs_close + +MODELS = [ + "google/gemma-2-9b" +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + os.environ['VLLM_ATTENTION_BACKEND'] = "FLASHINFER" + # TODO(sang): Sliding window should be tested separately. + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 9dff9d55df46e..b10d547e5c520 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -5,8 +5,6 @@ Run `pytest tests/models/test_models.py`. """ -import os - import pytest from .utils import check_outputs_equal @@ -22,7 +20,6 @@ # "allenai/OLMo-1B", # Broken "bigcode/starcoder2-3b", "google/gemma-1.1-2b-it", - "google/gemma-2-9b" ] @@ -43,9 +40,6 @@ def test_models( with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - if "gemma-2" in model: - os.environ['VLLM_ATTENTION_BACKEND'] = "FLASHINFER" - with vllm_runner(model, dtype=dtype) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) @@ -68,4 +62,4 @@ def test_model_print( # This test is for verifying whether the model's extra_repr # can be printed correctly. print(vllm_model.model.llm_engine.model_executor.driver_worker. - model_runner.model) + model_runner.model) \ No newline at end of file diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 851bf52a505ee..32d2fdb3fafbf 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -78,7 +78,7 @@ def get_attn_backend( elif backend == _Backend.FLASHINFER: logger.info("Using Flashinfer backend.") logger.warning(("Flashinfer will be stuck on llma-2-7b," - " please avoid using Flashinfer as the" + " please avoid using Flashinfer as the " "backend when running on llma-2-7b.")) from vllm.attention.backends.flashinfer import FlashInferBackend return FlashInferBackend From afa1ef3ae8675f9010748957dd30aa3372ca2d2a Mon Sep 17 00:00:00 2001 From: lilyliu Date: Mon, 1 Jul 2024 16:37:24 -0700 Subject: [PATCH 03/13] minor --- tests/models/test_gemma2.py | 11 +++++------ tests/models/test_models.py | 2 +- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/models/test_gemma2.py b/tests/models/test_gemma2.py index e11527a056fee..b2210002f6e0c 100644 --- a/tests/models/test_gemma2.py +++ b/tests/models/test_gemma2.py @@ -1,15 +1,14 @@ -"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling. +"""Compare the outputs of HF and vLLM for Gemma2 models using greedy sampling. -Run `pytest tests/models/test_mistral.py`. +Run `pytest tests/models/test_gemma2.py`. """ -import pytest import os +import pytest + from .utils import check_logprobs_close -MODELS = [ - "google/gemma-2-9b" -] +MODELS = ["google/gemma-2-9b"] @pytest.mark.parametrize("model", MODELS) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index b10d547e5c520..4cd2cb665c8f0 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -62,4 +62,4 @@ def test_model_print( # This test is for verifying whether the model's extra_repr # can be printed correctly. print(vllm_model.model.llm_engine.model_executor.driver_worker. - model_runner.model) \ No newline at end of file + model_runner.model) From 6426e9b2a54ee64d76b4b271c3d1be3d09e5ee66 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Mon, 1 Jul 2024 16:43:35 -0700 Subject: [PATCH 04/13] format --- vllm/worker/model_runner.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 51cb8276188e8..dc906bc993b23 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -663,11 +663,11 @@ def _prepare_model_input_tensors( 'final_logit_softcapping', None) if logits_soft_cap is not None and self.attn_backend.get_name( ) != "flashinfer": - logger.warning(("Please use Flashinfer backend for models with", - "logits_soft_cap (i.e., Gemma-2).", - " Otherwise, the output might be wrong.", - " Set Flashinfer backend by ", - "export VLLM_ATTENTION_BACKEND=FLASHINFER.")) + logger.warning("Please use Flashinfer backend for models with" + "logits_soft_cap (i.e., Gemma-2)." + " Otherwise, the output might be wrong." + " Set Flashinfer backend by " + "export VLLM_ATTENTION_BACKEND=FLASHINFER.") if self.attn_backend.get_name() == "flashinfer": if len(paged_kv_indptr) > 0: From 8423d03b9f640d6ff325940b7127d32c2965d321 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Mon, 1 Jul 2024 16:44:01 -0700 Subject: [PATCH 05/13] format --- vllm/worker/model_runner.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index dc906bc993b23..e3327356aff98 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -664,10 +664,10 @@ def _prepare_model_input_tensors( if logits_soft_cap is not None and self.attn_backend.get_name( ) != "flashinfer": logger.warning("Please use Flashinfer backend for models with" - "logits_soft_cap (i.e., Gemma-2)." - " Otherwise, the output might be wrong." - " Set Flashinfer backend by " - "export VLLM_ATTENTION_BACKEND=FLASHINFER.") + "logits_soft_cap (i.e., Gemma-2)." + " Otherwise, the output might be wrong." + " Set Flashinfer backend by " + "export VLLM_ATTENTION_BACKEND=FLASHINFER.") if self.attn_backend.get_name() == "flashinfer": if len(paged_kv_indptr) > 0: From 54f85480381a8d861bd28cd5d5f4e7220db6d60e Mon Sep 17 00:00:00 2001 From: lilyliu Date: Tue, 2 Jul 2024 13:10:37 -0700 Subject: [PATCH 06/13] add flashinfer unit test, update error message --- .buildkite/test-pipeline.yaml | 7 +- tests/kernels/test_flashinfer.py | 258 +++++++++++++++++++++++++++++++ vllm/worker/model_runner.py | 10 +- 3 files changed, 268 insertions(+), 7 deletions(-) create mode 100644 tests/kernels/test_flashinfer.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d96e3c6d192e2..10e701ca3b06d 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -107,12 +107,15 @@ steps: - label: Kernels Test %N #mirror_hardwares: [amd] - command: pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + command: + - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.7/flashinfer-0.0.7+cu121torch2.3-cp310-cp310-linux_x86_64.whl + - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 - label: Models Test #mirror_hardwares: [amd] commands: + - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.7/flashinfer-0.0.7+cu121torch2.3-cp310-cp310-linux_x86_64.whl - pytest -v -s models -m \"not vlm\" - label: Vision Language Models Test @@ -223,7 +226,7 @@ steps: - pytest -v -s distributed/test_custom_all_reduce.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py - - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl + - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.7/flashinfer-0.0.7+cu121torch2.3-cp310-cp310-linux_x86_64.whl - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py - pytest -v -s -x lora/test_mixtral.py diff --git a/tests/kernels/test_flashinfer.py b/tests/kernels/test_flashinfer.py new file mode 100644 index 0000000000000..e1ec007835116 --- /dev/null +++ b/tests/kernels/test_flashinfer.py @@ -0,0 +1,258 @@ +from typing import List, Optional, Tuple + +import pytest +import torch +import flashinfer + +NUM_HEADS = [(16, 16), (32, 8), (64, 8)] +HEAD_SIZES = [128, 256] +BLOCK_SIZES = [16, 32] +DTYPES = [torch.float16, torch.bfloat16] +NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation. + + +def ref_paged_attn( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + query_lens: List[int], + kv_lens: List[int], + block_tables: torch.Tensor, + scale: float, + sliding_window: Optional[int] = None, + soft_cap: Optional[float] = None, +) -> torch.Tensor: + num_seqs = len(query_lens) + block_tables = block_tables.cpu().numpy() + _, block_size, num_kv_heads, head_size = key_cache.shape + + outputs: List[torch.Tensor] = [] + start_idx = 0 + for i in range(num_seqs): + query_len = query_lens[i] + kv_len = kv_lens[i] + q = query[start_idx:start_idx + query_len] + q *= scale + + num_kv_blocks = (kv_len + block_size - 1) // block_size + block_indices = block_tables[i, :num_kv_blocks] + + k = key_cache[block_indices].view(-1, num_kv_heads, head_size) + k = k[:kv_len] + v = value_cache[block_indices].view(-1, num_kv_heads, head_size) + v = v[:kv_len] + + if q.shape[1] != k.shape[1]: + k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) + v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) + attn = torch.einsum("qhd,khd->hqk", q, k).float() + empty_mask = torch.ones(query_len, kv_len) + mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() + if sliding_window is not None: + sliding_window_mask = torch.triu(empty_mask, + diagonal=kv_len - + (query_len + sliding_window) + + 1).bool().logical_not() + mask |= sliding_window_mask + if soft_cap is not None: + attn = soft_cap * torch.tanh(attn / soft_cap) + attn.masked_fill_(mask, float("-inf")) + attn = torch.softmax(attn, dim=-1).to(v.dtype) + out = torch.einsum("hqk,khd->qhd", attn, v) + + outputs.append(out) + start_idx += query_len + + return torch.cat(outputs, dim=0) + + +@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("soft_cap", [None, 10.0, 20.0, 30.0]) +@torch.inference_mode +def test_flashinfer_decode_with_paged_kv( + kv_lens: List[int], + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float] +) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + num_seqs = len(kv_lens) + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + + query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + key_value_cache = torch.randn( + NUM_BLOCKS, 2, block_size, num_kv_heads, head_size, dtype=dtype + ) + key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) + value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(num_seqs): + seq_len = kv_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) + + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) + wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD") + wrapper.begin_forward( + kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + data_type=dtype + ) + + output = wrapper.forward( + query, key_value_cache, logits_soft_cap=soft_cap + ) + + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap + ) + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}" + + + + +@pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("soft_cap", [None, 10.0, 20.0, 30.0]) +@torch.inference_mode +def test_flashinfer_prefill_with_paged_kv( + seq_lens: List[Tuple[int, int]], + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float] +) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + num_seqs = len(seq_lens) + query_lens = [x[0] for x in seq_lens] + kv_lens = [x[1] for x in seq_lens] + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + + query = torch.randn(sum(query_lens), + num_query_heads, + head_size, + dtype=dtype) + key_value_cache = torch.randn( + NUM_BLOCKS, 2, block_size, num_kv_heads, head_size, dtype=dtype + ) + key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) + value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) + + # Normalize the scale of the key and value caches to mitigate + # numerical instability. + key_cache /= head_size**0.5 + value_cache /= head_size**0.5 + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + qo_indptr = [0] + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(num_seqs): + seq_len = kv_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) + qo_indptr.append(qo_indptr[-1] + query_lens[i]) + + qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32) + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, "NHD" + ) + wrapper.begin_forward( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + ) + + output = wrapper.forward( + query, key_value_cache, + logits_soft_cap=soft_cap, + ) + + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap + ) + assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}" diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e3327356aff98..70aeb12952a95 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -663,11 +663,11 @@ def _prepare_model_input_tensors( 'final_logit_softcapping', None) if logits_soft_cap is not None and self.attn_backend.get_name( ) != "flashinfer": - logger.warning("Please use Flashinfer backend for models with" - "logits_soft_cap (i.e., Gemma-2)." - " Otherwise, the output might be wrong." - " Set Flashinfer backend by " - "export VLLM_ATTENTION_BACKEND=FLASHINFER.") + raise ValueError("Please use Flashinfer backend for models with" + "logits_soft_cap (i.e., Gemma-2)." + " Otherwise, the output might be wrong." + " Set Flashinfer backend by " + "export VLLM_ATTENTION_BACKEND=FLASHINFER.") if self.attn_backend.get_name() == "flashinfer": if len(paged_kv_indptr) > 0: From c0f298b9af8b96bad36d4f9f72ebd38581ef3350 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Tue, 2 Jul 2024 13:26:48 -0700 Subject: [PATCH 07/13] format --- tests/kernels/test_flashinfer.py | 148 ++++++++++++++----------------- 1 file changed, 69 insertions(+), 79 deletions(-) diff --git a/tests/kernels/test_flashinfer.py b/tests/kernels/test_flashinfer.py index e1ec007835116..5211be6aef009 100644 --- a/tests/kernels/test_flashinfer.py +++ b/tests/kernels/test_flashinfer.py @@ -1,8 +1,8 @@ from typing import List, Optional, Tuple +import flashinfer import pytest import torch -import flashinfer NUM_HEADS = [(16, 16), (32, 8), (64, 8)] HEAD_SIZES = [128, 256] @@ -71,16 +71,13 @@ def ref_paged_attn( @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 10.0, 20.0, 30.0]) +@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @torch.inference_mode -def test_flashinfer_decode_with_paged_kv( - kv_lens: List[int], - num_heads: Tuple[int, int], - head_size: int, - dtype: torch.dtype, - block_size: int, - soft_cap: Optional[float] -) -> None: +def test_flashinfer_decode_with_paged_kv(kv_lens: List[int], + num_heads: Tuple[int, + int], head_size: int, + dtype: torch.dtype, block_size: int, + soft_cap: Optional[float]) -> None: torch.set_default_device("cuda") torch.cuda.manual_seed_all(0) num_seqs = len(kv_lens) @@ -91,9 +88,12 @@ def test_flashinfer_decode_with_paged_kv( scale = head_size**-0.5 query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) - key_value_cache = torch.randn( - NUM_BLOCKS, 2, block_size, num_kv_heads, head_size, dtype=dtype - ) + key_value_cache = torch.randn(NUM_BLOCKS, + 2, + block_size, + num_kv_heads, + head_size, + dtype=dtype) key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) @@ -102,7 +102,7 @@ def test_flashinfer_decode_with_paged_kv( NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) - + kv_indptr = [0] kv_indices = [] kv_last_page_lens = [] @@ -114,62 +114,52 @@ def test_flashinfer_decode_with_paged_kv( kv_indptr.append(kv_indptr[-1] + num_blocks) kv_last_page_len = seq_len % block_size if kv_last_page_len == 0: - kv_last_page_len = block_size + kv_last_page_len = block_size kv_last_page_lens.append(kv_last_page_len) - + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) - kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - - workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) - wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD") - wrapper.begin_forward( - kv_indptr, - kv_indices, - kv_last_page_lens, - num_query_heads, - num_kv_heads, - head_size, - block_size, - "NONE", - data_type=dtype - ) - - output = wrapper.forward( - query, key_value_cache, logits_soft_cap=soft_cap - ) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - ref_output = ref_paged_attn( - query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=[1] * num_seqs, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap - ) + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) + wrapper = flashinfer.\ + BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD") + wrapper.begin_forward(kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + data_type=dtype) + + output = wrapper.forward(query, key_value_cache, logits_soft_cap=soft_cap) + + ref_output = ref_paged_attn(query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap) assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}" - - @pytest.mark.parametrize("seq_lens", [[(1, 1328), (5, 18), (129, 463)]]) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("soft_cap", [None, 10.0, 20.0, 30.0]) +@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @torch.inference_mode -def test_flashinfer_prefill_with_paged_kv( - seq_lens: List[Tuple[int, int]], - num_heads: Tuple[int, int], - head_size: int, - dtype: torch.dtype, - block_size: int, - soft_cap: Optional[float] -) -> None: +def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], + num_heads: Tuple[int, int], + head_size: int, dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float]) -> None: torch.set_default_device("cuda") torch.cuda.manual_seed_all(0) num_seqs = len(seq_lens) @@ -185,9 +175,12 @@ def test_flashinfer_prefill_with_paged_kv( num_query_heads, head_size, dtype=dtype) - key_value_cache = torch.randn( - NUM_BLOCKS, 2, block_size, num_kv_heads, head_size, dtype=dtype - ) + key_value_cache = torch.randn(NUM_BLOCKS, + 2, + block_size, + num_kv_heads, + head_size, + dtype=dtype) key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) @@ -214,20 +207,18 @@ def test_flashinfer_prefill_with_paged_kv( kv_indptr.append(kv_indptr[-1] + num_blocks) kv_last_page_len = seq_len % block_size if kv_last_page_len == 0: - kv_last_page_len = block_size + kv_last_page_len = block_size kv_last_page_lens.append(kv_last_page_len) qo_indptr.append(qo_indptr[-1] + query_lens[i]) - - qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32) + + qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32) kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) kv_indices = torch.tensor(kv_indices, dtype=torch.int32) - kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) - + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( - workspace_buffer, "NHD" - ) + workspace_buffer, "NHD") wrapper.begin_forward( qo_indptr, kv_indptr, @@ -238,21 +229,20 @@ def test_flashinfer_prefill_with_paged_kv( head_size, block_size, ) - + output = wrapper.forward( - query, key_value_cache, - logits_soft_cap=soft_cap, + query, + key_value_cache, + logits_soft_cap=soft_cap, ) - ref_output = ref_paged_attn( - query=query, - key_cache=key_cache, - value_cache=value_cache, - query_lens=query_lens, - kv_lens=kv_lens, - block_tables=block_tables, - scale=scale, - soft_cap=soft_cap - ) + ref_output = ref_paged_attn(query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap) assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}" From d248cef0905c125ee2bdcbf18fb507e1736e03ba Mon Sep 17 00:00:00 2001 From: Lily Liu Date: Tue, 2 Jul 2024 14:16:07 -0700 Subject: [PATCH 08/13] Update .buildkite/test-pipeline.yaml Co-authored-by: Simon Mo --- .buildkite/test-pipeline.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 10e701ca3b06d..187530eb583ad 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -107,7 +107,7 @@ steps: - label: Kernels Test %N #mirror_hardwares: [amd] - command: + commands: - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.7/flashinfer-0.0.7+cu121torch2.3-cp310-cp310-linux_x86_64.whl - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT parallelism: 4 From 09479614f9bba917c3d71a5c02a4b63e8728f24b Mon Sep 17 00:00:00 2001 From: lilyliu Date: Tue, 2 Jul 2024 22:08:50 -0700 Subject: [PATCH 09/13] fix --- vllm/worker/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 70aeb12952a95..1346ec1c29599 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -660,7 +660,7 @@ def _prepare_model_input_tensors( device=self.device) logits_soft_cap = getattr(self.model_config.hf_config, - 'final_logit_softcapping', None) + 'attn_logit_softcapping', None) if logits_soft_cap is not None and self.attn_backend.get_name( ) != "flashinfer": raise ValueError("Please use Flashinfer backend for models with" From 0cc481c9ff9fd7a6c4b7fbb69e23768efca88c1c Mon Sep 17 00:00:00 2001 From: lilyliu Date: Wed, 3 Jul 2024 20:55:59 -0700 Subject: [PATCH 10/13] remove warning --- vllm/attention/selector.py | 4 ++-- vllm/model_executor/models/gemma2.py | 7 ------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 32d2fdb3fafbf..ae63eb1d48f8d 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -77,9 +77,9 @@ def get_attn_backend( return IpexAttnBackend elif backend == _Backend.FLASHINFER: logger.info("Using Flashinfer backend.") - logger.warning(("Flashinfer will be stuck on llma-2-7b," + logger.warning(("Flashinfer will be stuck on llama-2-7b," " please avoid using Flashinfer as the " - "backend when running on llma-2-7b.")) + "backend when running on llama-2-7b.")) from vllm.attention.backends.flashinfer import FlashInferBackend return FlashInferBackend elif backend == _Backend.PALLAS: diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 4e35a9ec34069..54f27007e55a6 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -38,7 +38,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput -from vllm.utils import print_warning_once from .interfaces import SupportsLoRA @@ -137,12 +136,6 @@ def __init__(self, dtype=torch.get_default_dtype(), ) - if self.config.attn_logit_softcapping is not None: - print_warning_once( - "Gemma 2 normally uses attention logit soft-capping; " - "soft-capping is currently incompatible with the flash " - "attention kernels, so vLLM removes it to enable speed and " - "efficiency gains of flash attention.") # FIXME(woosuk): While Gemma 2 uses sliding window attention for every # odd layer, vLLM currently ignores it and uses global attention for # all layers. From cf0268286ea063ceaa6d1111aabad020b657e6cb Mon Sep 17 00:00:00 2001 From: lilyliu Date: Wed, 3 Jul 2024 21:06:15 -0700 Subject: [PATCH 11/13] fix merge bug --- vllm/worker/model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 7d4137d982d9a..7ffbf29e3e0c4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -15,7 +15,7 @@ from flashinfer import BatchDecodeWithPagedKVCacheWrapper from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper - FLASHINFER_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024 + FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 except ImportError: BatchDecodeWithPagedKVCacheWrapper = None CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None @@ -1205,7 +1205,7 @@ def execute_model( self.flashinfer_prefill_wrapper if model_input.attn_metadata.use_cuda_graph: batch_size = model_input.input_tokens.shape[0] - model_input.attn_metadata.decode_wrapper = self.graph_runners[ + model_input.attn_metadata.decode_wrapper = self.graph_runners[model_input.virtual_engine][ batch_size].flashinfer_decode_wrapper else: model_input.attn_metadata.decode_wrapper = \ From 8e4d1d33ca0616ea7c4ed1ac19684a92bd275067 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Wed, 3 Jul 2024 21:07:51 -0700 Subject: [PATCH 12/13] format --- vllm/model_executor/models/gemma2.py | 1 - vllm/worker/model_runner.py | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index f61f9c72bf081..8386084c2b3f8 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -39,7 +39,6 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors, SamplerOutput - from .interfaces import SupportsLoRA diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 7ffbf29e3e0c4..2ae5263baa18c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1205,8 +1205,9 @@ def execute_model( self.flashinfer_prefill_wrapper if model_input.attn_metadata.use_cuda_graph: batch_size = model_input.input_tokens.shape[0] - model_input.attn_metadata.decode_wrapper = self.graph_runners[model_input.virtual_engine][ - batch_size].flashinfer_decode_wrapper + model_input.attn_metadata.decode_wrapper = self.graph_runners[ + model_input. + virtual_engine][batch_size].flashinfer_decode_wrapper else: model_input.attn_metadata.decode_wrapper = \ self.flashinfer_decode_wrapper From 52570158cb762026d18b2d387a6b3af027ec6346 Mon Sep 17 00:00:00 2001 From: lilyliu Date: Thu, 4 Jul 2024 10:29:57 -0700 Subject: [PATCH 13/13] remove model test since unit tests already convered and unblock release --- tests/models/test_gemma2.py | 41 ------------------------------------- 1 file changed, 41 deletions(-) delete mode 100644 tests/models/test_gemma2.py diff --git a/tests/models/test_gemma2.py b/tests/models/test_gemma2.py deleted file mode 100644 index b2210002f6e0c..0000000000000 --- a/tests/models/test_gemma2.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Compare the outputs of HF and vLLM for Gemma2 models using greedy sampling. - -Run `pytest tests/models/test_gemma2.py`. -""" -import os - -import pytest - -from .utils import check_logprobs_close - -MODELS = ["google/gemma-2-9b"] - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [64]) -@pytest.mark.parametrize("num_logprobs", [5]) -def test_models( - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - num_logprobs: int, -) -> None: - os.environ['VLLM_ATTENTION_BACKEND'] = "FLASHINFER" - # TODO(sang): Sliding window should be tested separately. - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, max_tokens, num_logprobs) - - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, max_tokens, num_logprobs) - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - )