From 28aad31a089e159cf71caf86d72c5610d80cec08 Mon Sep 17 00:00:00 2001 From: derekk-nm Date: Mon, 13 May 2024 13:50:59 +0000 Subject: [PATCH 1/5] Basic server correctness test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introducing an end-to-end test case that verifies basic correctness of the vllm server by comparing the tokens output by the vllm OpenAI server with tokens generated by the HuggingFace model created with AutoModelForCausalLM.from_pretrained(). Updates HfRunner() to accept a HuggingFace access token to be able to retrieve models that are restricted access The new HfRunnerNM.generate_greedy_logprobs_nm_use_tokens() allows us to compare the HuggingFace generated results (which reports logprobs with token ids) with that from the vllm OpenAI Server (which reports logprobs with token text). This included a new _decode_token_by_position_index() method to properly calculate the token string by using a lookback on the generated tokens list. Enhances the output of the check_logprobs_close() function to provide more details about the failing tokens. Adds the test to the appropriate skip-*.txt files so that this long running test won’t get automatically run during automatic dev push workflows. --- neuralmagic/tests/skip-almost-all.txt | 1 + neuralmagic/tests/skip-for-remote-push.txt | 1 + requirements-dev.txt | 1 + tests/basic_correctness/__init__.py | 0 .../test_basic_server_correctness.py | 212 ++++++++++++++++++ tests/conftest.py | 132 +++++++++-- tests/models/compare_utils.py | 12 +- 7 files changed, 339 insertions(+), 20 deletions(-) create mode 100644 tests/basic_correctness/__init__.py create mode 100644 tests/basic_correctness/test_basic_server_correctness.py diff --git a/neuralmagic/tests/skip-almost-all.txt b/neuralmagic/tests/skip-almost-all.txt index 99a541c7e1628..e2e7cf036f569 100644 --- a/neuralmagic/tests/skip-almost-all.txt +++ b/neuralmagic/tests/skip-almost-all.txt @@ -84,4 +84,5 @@ tests/engine/test_detokenization.py tests/engine/test_computed_prefix_blocks.py tests/basic_correctness/test_chunked_prefill.py tests/basic_correctness/test_basic_correctness.py +tests/basic_correctness/test_basic_server_correctness.py tests/test_cache_block_hashing.py diff --git a/neuralmagic/tests/skip-for-remote-push.txt b/neuralmagic/tests/skip-for-remote-push.txt index 9907f8b51c66e..c52db98f46a1d 100644 --- a/neuralmagic/tests/skip-for-remote-push.txt +++ b/neuralmagic/tests/skip-for-remote-push.txt @@ -49,4 +49,5 @@ tests/lora/test_lora.py tests/worker/test_model_runner.py tests/engine/test_detokenize.py tests/engine/test_computed_prefix_blocks.py +tests/basic_correctness/test_basic_server_correctness.py tests/accuracy/test_lm_eval_correctness.py diff --git a/requirements-dev.txt b/requirements-dev.txt index e0418fc7c250f..53de125bdc71e 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -13,6 +13,7 @@ types-requests==2.31.0.2 types-setuptools # testing +datasets pytest tensorizer==2.9.0 pytest-forked diff --git a/tests/basic_correctness/__init__.py b/tests/basic_correctness/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/basic_correctness/test_basic_server_correctness.py b/tests/basic_correctness/test_basic_server_correctness.py new file mode 100644 index 0000000000000..3d0abd0c04387 --- /dev/null +++ b/tests/basic_correctness/test_basic_server_correctness.py @@ -0,0 +1,212 @@ +import asyncio +from os import getenv +from typing import Dict, List, Type + +import openai +import pytest +import torch +from datasets import load_dataset +from openai import AsyncOpenAI +from transformers import AutoTokenizer + +from tests.conftest import HfRunnerNM +from tests.models.compare_utils import check_logprobs_close +from tests.utils.logging import make_logger +from tests.utils.server import ServerContext + + +@pytest.fixture(scope="session") +def client(): + client = openai.AsyncOpenAI( + base_url="http://localhost:8000/v1", + api_key="token-abc123", + ) + yield client + + +@pytest.fixture +def hf_runner_nm() -> Type[HfRunnerNM]: + return HfRunnerNM + + +async def my_chat( + client, + model: str, + messages: List[Dict], + max_tokens: int, + temperature: float, + num_logprobs: int, +): + """ submit a single prompt chat and collect results. """ + return await client.chat.completions.create(model=model, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + logprobs=True, + top_logprobs=num_logprobs) + + +@pytest.mark.parametrize( + "model, max_model_len, sparsity", + [ + ("mistralai/Mistral-7B-Instruct-v0.2", 4096, None), + # pytest.param("mistralai/Mixtral-8x7B-Instruct-v0.1", 4096, None, + # marks=pytest.mark.skip( + # "skipped because the HFRunner " + # "will need the 'optimum' package")), + # ("neuralmagic/zephyr-7b-beta-marlin", 4096, None), + # ("neuralmagic/OpenHermes-2.5-Mistral-7B-pruned50", + # 4096, "sparse_w16a16"), + # ("NousResearch/Llama-2-7b-chat-hf", 4096, None), + # pytest.param( + # "neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", + # None, + # None, + # marks=pytest.mark.skip( + # "skipped because the HFRunner will need the " + # "'optimum' package") + # ), + # ("neuralmagic/Llama-2-7b-pruned70-retrained-ultrachat", + # 4096, "sparse_w16a16"), + # ("HuggingFaceH4/zephyr-7b-gemma-v0.1", 4096, None), + # ("Qwen/Qwen1.5-7B-Chat", 4096, None), + # ("microsoft/phi-2", 2048, None), + # pytest.param( + # "neuralmagic/phi-2-super-marlin", + # 2048, + # None, + # marks=pytest.mark.skip( + # "skipped because the HFRunner will need the " + # "'optimum' package") + # ), + # ("neuralmagic/phi-2-pruned50", 2048, "sparse_w16a16"), + # pytest.param( + # "Qwen/Qwen1.5-MoE-A2.7B-Chat", + # 4096, + # None, + # marks=pytest.mark.skip( + # "ValueError: The checkpoint you are trying to load has model" + # "type `qwen2_moe` but Transformers does not recognize this " + # "architecture. This could be because of an issue with the " + # "checkpoint, or because your version of Transformers is " + # "out of date.")), + # pytest.param("casperhansen/gemma-7b-it-awq", 4096, None, + # marks=pytest.mark.skip( + # "skipped because the HFRunner will need the " + # "autoawq library")), + # pytest.param( + # "TheBloke/Llama-2-7B-Chat-GPTQ", + # 4096, + # None, + # marks=pytest.mark.skip( + # "skipped because the HFRunner will need the " + # "'optimum' package") + # ), + ]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [3]) +@pytest.mark.parametrize("tensor_parallel_size", [None, 2]) +# note: repeating the test for 2 values of tensor_parallel_size +# increases the overall execution time by unnecessarily +# collecting the HuggingFace runner data twice. +# Consider refactoring to eliminate that repeat. +def test_models_on_server( + hf_runner_nm: HfRunnerNM, + client: AsyncOpenAI, + model: str, + max_model_len: int, + sparsity: str, + tensor_parallel_size: int, + max_tokens: int, + num_logprobs: int, +) -> None: + """ + This test compares the output of the vllm OpenAI server against that of + a HuggingFace transformer. We expect them to be fairly close. "Close" + is measured by checking that the top 3 logprobs for each token includes + the token of the other inference tool. + + :param hf_runner_nm: fixture for the HfRunnerNM + :param client: fixture with an openai.AsyncOpenAI client + :param model: The Hugginface id for a model to test with + :param max_model_len: passed to the vllm Server's --max-model-len option + :param sparsity: passed to the vllm Server's --sparsity option + :param tensor_parallel_size: passed to the vllm Server's + --tensor_parallel_size option + :param max_tokens: the total number of tokens to consider for closeness + :param num_logprobs: the total number of logprobs included when + calculating closeness + """ + logger = make_logger("vllm_test") + # check that the requested gpu count is available in the test env + gpu_count = torch.cuda.device_count() + if tensor_parallel_size and gpu_count < tensor_parallel_size: + pytest.skip(f"gpu count {gpu_count} is insufficient for " + f"tensor_parallel_size = {tensor_parallel_size}") + + hf_token = getenv("HF_TOKEN", None) + logger.info("loading chat prompts for testing.") + ds = load_dataset("nm-testing/qa-chat-prompts", split="train_sft") + num_chat_turns = 3 + messages_list = [row["messages"][:num_chat_turns] for row in ds] + tokenizer = AutoTokenizer.from_pretrained(model) + chat_prompts = [ + tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + for messages in messages_list + ] + + logger.info("generating chat responses from HuggingFace runner.") + hf_model = hf_runner_nm(model, access_token=hf_token) + hf_outputs = hf_model.generate_greedy_logprobs_nm_use_tokens( + chat_prompts, max_tokens, num_logprobs) + + del hf_model + + logger.info("generating chat responses from vllm server.") + api_server_args = { + "--model": model, + "--max-model-len": max_model_len, + "--disable-log-requests": None, + } + if sparsity: + api_server_args["--sparsity"] = sparsity + if tensor_parallel_size: + api_server_args["--tensor-parallel-size"] = tensor_parallel_size + + # some devices will require a different `dtype` + device_capability = torch.cuda.get_device_capability() + if device_capability[0] < 8: + api_server_args["--dtype"] = "half" + + asyncio_event_loop = asyncio.get_event_loop() + temperature = 0.0 + with ServerContext(api_server_args, logger=logger) as _: + # submit an asynchronous request to the server for each prompt + chats = [ + my_chat(client, model, messages, max_tokens, temperature, + num_logprobs) + for messages in [query for query in messages_list] + ] + # await for all the requests to return, and gather their results + # in one place + results = asyncio_event_loop.run_until_complete(asyncio.gather(*chats)) + + logger.info("preparing results from vllm server requests to include " + "tokens and logprobs.") + vllm_outputs = list() + for task_result in results: + for req_output in task_result.choices: + output_str = req_output.message.content + output_tokens = req_output.logprobs.model_extra["tokens"] + output_logprobs = req_output.logprobs.model_extra["top_logprobs"] + vllm_outputs.append((output_tokens, output_str, output_logprobs)) + + logger.info("comparing HuggingFace and vllm Server chat responses") + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf_model", + name_1="vllm_model", + ) diff --git a/tests/conftest.py b/tests/conftest.py index 3b557df55c8a0..1e32e2164e2ce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,7 @@ import gc import logging import os -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import pytest import torch @@ -143,6 +143,7 @@ def __init__( model_name: str, tokenizer_name: Optional[str] = None, dtype: str = "half", + access_token: Optional[str] = None, ) -> None: assert dtype in _STR_DTYPE_TO_TORCH_DTYPE torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] @@ -152,7 +153,7 @@ def __init__( model_name, torch_dtype=torch_dtype, trust_remote_code=True, - ).cuda() + token=access_token).cuda() self.processor = None else: self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained( @@ -287,12 +288,31 @@ def hf_runner(): # UPSTREAM SYNC: needed for nm-automation class HfRunnerNM(HfRunner): + def _logprobs_from_generated_hidden_states(self, output): + """ + generates a list of logprobs from the output of self.model.generate() + """ + seq_logprobs = [] + for _, hidden_states in enumerate(output.hidden_states): + last_hidden_states = hidden_states[-1][0] + logits = torch.matmul( + last_hidden_states, + self.model.get_output_embeddings().weight.t(), + ) + if self.model.get_output_embeddings().bias is not None: + logits += self.model.get_output_embeddings().bias.unsqueeze(0) + logprobs = torch.nn.functional.log_softmax(logits, + dim=-1, + dtype=torch.float32) + seq_logprobs.append(logprobs) + return seq_logprobs + def generate_greedy_logprobs_nm( self, prompts: List[str], max_tokens: int, num_logprobs: int, - ) -> List[Tuple[List[int], str]]: + ) -> List[Tuple[List[int], str, List[Dict]]]: all_logprobs = [] all_output_ids = [] all_output_strs = [] @@ -308,20 +328,7 @@ def generate_greedy_logprobs_nm( return_dict_in_generate=True, ) - seq_logprobs = [] - for _, hidden_states in enumerate(output.hidden_states): - last_hidden_states = hidden_states[-1][0] - logits = torch.matmul( - last_hidden_states, - self.model.get_output_embeddings().weight.t(), - ) - if self.model.get_output_embeddings().bias is not None: - logits += self.model.get_output_embeddings( - ).bias.unsqueeze(0) - logprobs = torch.nn.functional.log_softmax(logits, - dim=-1, - dtype=torch.float32) - seq_logprobs.append(logprobs) + seq_logprobs = self._logprobs_from_generated_hidden_states(output) # convert to dict seq_logprobs_lst = [] @@ -348,6 +355,97 @@ def generate_greedy_logprobs_nm( return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] + def _decode_token_by_position_index( + self, + token_index: int, + token_ids: List[int], + clean_up_tokenization_spaces: bool = True) -> str: + """ + helper function to calculate a token string at the specified index + based on the previous tokens. + + :param token_index: position in the list of token ids where you would + like the token string + :param token_ids: the list of token ids + :param clean_up_tokenization_spaces: option to pass to + `tokenizer.decode()` + """ + lookback = 4 + prior_str = self.tokenizer.decode( + token_ids[token_index - lookback:token_index - 1], + clean_up_tokenization_spaces=clean_up_tokenization_spaces) + current_str = self.tokenizer.decode( + token_ids[token_index - lookback:token_index], + clean_up_tokenization_spaces=clean_up_tokenization_spaces) + token = current_str[-(len(current_str) - len(prior_str)):] + return token + + def generate_greedy_logprobs_nm_use_tokens( + self, + prompts: List[str], + max_tokens: int, + topk_logprobs_count: int, + ) -> List[Tuple[List[int], str, List[Dict]]]: + all_logprobs = [] + all_output_tokens = [] + all_output_strs = [] + + for prompt in prompts: + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids + output = self.model.generate( + input_ids.cuda(), + use_cache=True, + do_sample=False, + max_new_tokens=max_tokens, + output_hidden_states=True, + return_dict_in_generate=True, + ) + + seq_logprobs = self._logprobs_from_generated_hidden_states(output) + + # convert sequence of logprobs to a dict keyed on the selected token + seq_ids = output.sequences[0] + input_len = input_ids.shape[1] + seq_logprobs_lst = list() + output_tokens = list() + + for tok_idx, tok_logprobs in enumerate(seq_logprobs): + # drop prompt logprobs + if tok_idx == 0: + tok_logprobs = tok_logprobs[-1, :].reshape(1, -1) + topk = tok_logprobs.topk(topk_logprobs_count) + + tok_logprobs_dct = {} + # add 1 to the index here to be in sync with the skipped prompt + # in the logprobs + indexed_seq = seq_ids[:input_len + tok_idx + 1].tolist() + token_str = self._decode_token_by_position_index( + input_len + tok_idx + 1, + indexed_seq, + clean_up_tokenization_spaces=False) + output_tokens.append(token_str) + for alt_token_id, logprob in zip(topk.indices[0], + topk.values[0]): + # replace the token_str at the tok_idx with alt_token_id. + indexed_seq[-1] = alt_token_id + # then run decode again + logprob_str = self._decode_token_by_position_index( + input_len + tok_idx + 1, + indexed_seq, + clean_up_tokenization_spaces=False) + tok_logprobs_dct[logprob_str] = logprob.item() + + seq_logprobs_lst.append(tok_logprobs_dct) + + all_logprobs.append(seq_logprobs_lst) + + all_output_tokens.append(output_tokens) + all_output_strs.append("".join(output_tokens)) + + outputs = zip(all_output_tokens, all_output_strs, all_logprobs) + return [(output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs] + def __del__(self): del self.model cleanup() diff --git a/tests/models/compare_utils.py b/tests/models/compare_utils.py index 44319b6ca45ff..051cbf1547b21 100644 --- a/tests/models/compare_utils.py +++ b/tests/models/compare_utils.py @@ -1,4 +1,4 @@ -"""Compare the logprobs of two sequences generated by different models, +"""Compare the logprobs of two sequences generated by different models, which should be similar but not necessarily equal. """ @@ -19,13 +19,19 @@ def check_logprobs_close(outputs_0_lst, outputs_1_lst, name_0, name_1): if output_id_0 != output_id_1: # Each predicted token must be in top N logprobs of the other assert output_id_0 in logprobs_1[idx], ( - f"Test{prompt_idx}:" + f"{name_0} token '{output_id_0}' not in " + f"[{list(logprobs_1[idx].keys())}]" + f"\nprompt index {prompt_idx}, token index {idx}:" f"\n{name_0}:\t{output_str_0!r}" f"\n{name_1}:\t{output_str_1!r}") assert output_id_1 in logprobs_0[idx], ( - f"Test{prompt_idx}:" + f"{name_1} token '{output_id_1}' not in " + f"[{list(logprobs_0[idx].keys())}]" + f"\nprompt index {prompt_idx}, token index {idx}:" f"\n{name_0}:\t{output_str_0!r}" f"\n{name_1}:\t{output_str_1!r}") # Break out since sequences will now diverge. + # as long as we got this far with the output tokens being the + # same, or close, the responses are close enough break From c5fecaa1c1a5387bb2d085f54a0b2a8343a9512a Mon Sep 17 00:00:00 2001 From: derekk-nm Date: Wed, 22 May 2024 12:30:08 +0000 Subject: [PATCH 2/5] Include other models Test other models. Skip execution if the model requires a GPU device capability greater than that available on the current device (reusing approach from test_gptq_marlin.py). adds a hack to ignore special tokens after decode of HuggingFace response so that we can fairly compare with vllm server response. --- .../test_basic_server_correctness.py | 101 +++++++++++------- tests/conftest.py | 20 +++- 2 files changed, 77 insertions(+), 44 deletions(-) diff --git a/tests/basic_correctness/test_basic_server_correctness.py b/tests/basic_correctness/test_basic_server_correctness.py index 3d0abd0c04387..39eec77f74fd3 100644 --- a/tests/basic_correctness/test_basic_server_correctness.py +++ b/tests/basic_correctness/test_basic_server_correctness.py @@ -13,6 +13,7 @@ from tests.models.compare_utils import check_logprobs_close from tests.utils.logging import make_logger from tests.utils.server import ServerContext +from vllm.model_executor.layers.quantization import get_quantization_config @pytest.fixture(scope="session") @@ -47,60 +48,68 @@ async def my_chat( @pytest.mark.parametrize( - "model, max_model_len, sparsity", + "model, max_model_len, sparsity, gptq_config", [ - ("mistralai/Mistral-7B-Instruct-v0.2", 4096, None), - # pytest.param("mistralai/Mixtral-8x7B-Instruct-v0.1", 4096, None, - # marks=pytest.mark.skip( - # "skipped because the HFRunner " - # "will need the 'optimum' package")), - # ("neuralmagic/zephyr-7b-beta-marlin", 4096, None), - # ("neuralmagic/OpenHermes-2.5-Mistral-7B-pruned50", - # 4096, "sparse_w16a16"), - # ("NousResearch/Llama-2-7b-chat-hf", 4096, None), + ("mistralai/Mistral-7B-Instruct-v0.2", 4096, None, None), + # pytest.param("mistralai/Mixtral-8x7B-Instruct-v0.1", 4096, None, None, + # # marks=pytest.mark.skip( + # # "skipped because the HFRunner " + # # "will need the 'optimum' package") + # ), + # pytest.param("neuralmagic/zephyr-7b-beta-marlin", 4096, None, None, + # # marks=pytest.mark.skip( + # # "skipped because the HFRunner " + # # "will need the 'optimum' package") + # ), + ("neuralmagic/OpenHermes-2.5-Mistral-7B-pruned50", 4096, + "sparse_w16a16", None), + ("NousResearch/Llama-2-7b-chat-hf", 4096, None, None), # pytest.param( # "neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", # None, # None, - # marks=pytest.mark.skip( - # "skipped because the HFRunner will need the " - # "'optimum' package") - # ), - # ("neuralmagic/Llama-2-7b-pruned70-retrained-ultrachat", - # 4096, "sparse_w16a16"), - # ("HuggingFaceH4/zephyr-7b-gemma-v0.1", 4096, None), - # ("Qwen/Qwen1.5-7B-Chat", 4096, None), - # ("microsoft/phi-2", 2048, None), - # pytest.param( - # "neuralmagic/phi-2-super-marlin", - # 2048, # None, - # marks=pytest.mark.skip( - # "skipped because the HFRunner will need the " - # "'optimum' package") + # # marks=pytest.mark.skip( + # # "skipped because the HFRunner will need the " + # # "'optimum' package") # ), - # ("neuralmagic/phi-2-pruned50", 2048, "sparse_w16a16"), + ("neuralmagic/Llama-2-7b-pruned70-retrained-ultrachat", 4096, + "sparse_w16a16", None), + ("HuggingFaceH4/zephyr-7b-gemma-v0.1", 4096, None, None), + ("Qwen/Qwen1.5-7B-Chat", 4096, None, None), + ("microsoft/phi-2", 2048, None, None), + pytest.param( + "neuralmagic/phi-2-super-marlin", + 2048, + None, + None, + marks=pytest.mark.skip( + "https://app.asana.com/0/1206976017967941/1207360919122996")), + pytest.param( + "neuralmagic/phi-2-pruned50", + 2048, + "sparse_w16a16", + None, + marks=pytest.mark.skip( + "https://app.asana.com/0/1206976017967941/1207360919122996")), # pytest.param( # "Qwen/Qwen1.5-MoE-A2.7B-Chat", # 4096, # None, + # None, # marks=pytest.mark.skip( - # "ValueError: The checkpoint you are trying to load has model" - # "type `qwen2_moe` but Transformers does not recognize this " - # "architecture. This could be because of an issue with the " - # "checkpoint, or because your version of Transformers is " - # "out of date.")), - # pytest.param("casperhansen/gemma-7b-it-awq", 4096, None, - # marks=pytest.mark.skip( - # "skipped because the HFRunner will need the " - # "autoawq library")), + # "CUDA out of memory. Tried to allocate 20.00 MiB. GPU") + # ), + pytest.param("casperhansen/gemma-7b-it-awq", 4096, None, + "gptq_marlin"), # pytest.param( # "TheBloke/Llama-2-7B-Chat-GPTQ", # 4096, # None, - # marks=pytest.mark.skip( - # "skipped because the HFRunner will need the " - # "'optimum' package") + # None, + # # marks=pytest.mark.skip( + # # "skipped because the HFRunner will need the " + # # "'optimum' package") # ), ]) @pytest.mark.parametrize("max_tokens", [32]) @@ -116,6 +125,7 @@ def test_models_on_server( model: str, max_model_len: int, sparsity: str, + gptq_config: str, tensor_parallel_size: int, max_tokens: int, num_logprobs: int, @@ -131,6 +141,8 @@ def test_models_on_server( :param model: The Hugginface id for a model to test with :param max_model_len: passed to the vllm Server's --max-model-len option :param sparsity: passed to the vllm Server's --sparsity option + :param gptq_config: quantization method id for this model. default None + means quantization isn't involved. :param tensor_parallel_size: passed to the vllm Server's --tensor_parallel_size option :param max_tokens: the total number of tokens to consider for closeness @@ -144,6 +156,16 @@ def test_models_on_server( pytest.skip(f"gpu count {gpu_count} is insufficient for " f"tensor_parallel_size = {tensor_parallel_size}") + # skip this model if the current device does not have the required + # gpu capability. + device_capability = torch.cuda.get_device_capability() + capability = device_capability[0] * 10 + device_capability[1] + if gptq_config and ( + capability < + get_quantization_config(gptq_config).get_min_capability()): + pytest.skip("insufficient system GPU device capability " + f"({capability}) for this model") + hf_token = getenv("HF_TOKEN", None) logger.info("loading chat prompts for testing.") ds = load_dataset("nm-testing/qa-chat-prompts", split="train_sft") @@ -160,7 +182,7 @@ def test_models_on_server( logger.info("generating chat responses from HuggingFace runner.") hf_model = hf_runner_nm(model, access_token=hf_token) hf_outputs = hf_model.generate_greedy_logprobs_nm_use_tokens( - chat_prompts, max_tokens, num_logprobs) + chat_prompts, max_tokens, num_logprobs, ignore_special_tokens=True) del hf_model @@ -176,7 +198,6 @@ def test_models_on_server( api_server_args["--tensor-parallel-size"] = tensor_parallel_size # some devices will require a different `dtype` - device_capability = torch.cuda.get_device_capability() if device_capability[0] < 8: api_server_args["--dtype"] = "half" diff --git a/tests/conftest.py b/tests/conftest.py index 1e32e2164e2ce..6c264b09ecb75 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -153,13 +153,15 @@ def __init__( model_name, torch_dtype=torch_dtype, trust_remote_code=True, - token=access_token).cuda() + token=access_token, + ).cuda() self.processor = None else: self.model = _VISION_LANGUAGE_MODELS[model_name].from_pretrained( model_name, torch_dtype=torch_dtype, trust_remote_code=True, + token=access_token, ).cuda() self.processor = AutoProcessor.from_pretrained( model_name, @@ -355,11 +357,14 @@ def generate_greedy_logprobs_nm( return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] + SPECIAL_TOKENS = ["<|im_start|>", "<|im_end|>", "�", ""] + def _decode_token_by_position_index( self, token_index: int, token_ids: List[int], - clean_up_tokenization_spaces: bool = True) -> str: + clean_up_tokenization_spaces: bool = True, + ignore_special_tokens: bool = False) -> str: """ helper function to calculate a token string at the specified index based on the previous tokens. @@ -369,6 +374,8 @@ def _decode_token_by_position_index( :param token_ids: the list of token ids :param clean_up_tokenization_spaces: option to pass to `tokenizer.decode()` + :param ignore_special_tokens: converts hard coded special tokens to + an empty string """ lookback = 4 prior_str = self.tokenizer.decode( @@ -378,6 +385,8 @@ def _decode_token_by_position_index( token_ids[token_index - lookback:token_index], clean_up_tokenization_spaces=clean_up_tokenization_spaces) token = current_str[-(len(current_str) - len(prior_str)):] + if ignore_special_tokens and token in self.SPECIAL_TOKENS: + token = "" return token def generate_greedy_logprobs_nm_use_tokens( @@ -385,6 +394,7 @@ def generate_greedy_logprobs_nm_use_tokens( prompts: List[str], max_tokens: int, topk_logprobs_count: int, + ignore_special_tokens: bool = False ) -> List[Tuple[List[int], str, List[Dict]]]: all_logprobs = [] all_output_tokens = [] @@ -422,7 +432,8 @@ def generate_greedy_logprobs_nm_use_tokens( token_str = self._decode_token_by_position_index( input_len + tok_idx + 1, indexed_seq, - clean_up_tokenization_spaces=False) + clean_up_tokenization_spaces=False, + ignore_special_tokens=True) output_tokens.append(token_str) for alt_token_id, logprob in zip(topk.indices[0], topk.values[0]): @@ -432,7 +443,8 @@ def generate_greedy_logprobs_nm_use_tokens( logprob_str = self._decode_token_by_position_index( input_len + tok_idx + 1, indexed_seq, - clean_up_tokenization_spaces=False) + clean_up_tokenization_spaces=False, + ignore_special_tokens=True) tok_logprobs_dct[logprob_str] = logprob.item() seq_logprobs_lst.append(tok_logprobs_dct) From 67acb7f8a1c0cfb024c4982a78cc55318ece49ae Mon Sep 17 00:00:00 2001 From: derekk-nm Date: Tue, 28 May 2024 15:21:06 +0000 Subject: [PATCH 3/5] Skip microsoft/phi-2 this model fails the test with a specific prompt. to be addressed later. --- tests/basic_correctness/test_basic_server_correctness.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/basic_correctness/test_basic_server_correctness.py b/tests/basic_correctness/test_basic_server_correctness.py index 39eec77f74fd3..2d05fe045e6a2 100644 --- a/tests/basic_correctness/test_basic_server_correctness.py +++ b/tests/basic_correctness/test_basic_server_correctness.py @@ -77,7 +77,13 @@ async def my_chat( "sparse_w16a16", None), ("HuggingFaceH4/zephyr-7b-gemma-v0.1", 4096, None, None), ("Qwen/Qwen1.5-7B-Chat", 4096, None, None), - ("microsoft/phi-2", 2048, None, None), + pytest.param( + "microsoft/phi-2", + 2048, + None, + None, + marks=pytest.mark.skip( + "https://app.asana.com/0/1206976017967941/1207409474409275")), pytest.param( "neuralmagic/phi-2-super-marlin", 2048, From 01da8d4f0e825c17f0e1445b065286c40b59035c Mon Sep 17 00:00:00 2001 From: derekk-nm Date: Tue, 28 May 2024 22:30:39 +0000 Subject: [PATCH 4/5] removed commented models entries have been moved to the bug report, where failing models will be tracked. removed some additional models that do not work in the build/test env (until a resolution is found) expanded doc on the test case added a README for the *_skip.txt files. --- neuralmagic/tests/README.md | 29 ++++++ requirements-dev.txt | 1 + .../test_basic_server_correctness.py | 94 +++++-------------- 3 files changed, 51 insertions(+), 73 deletions(-) create mode 100644 neuralmagic/tests/README.md diff --git a/neuralmagic/tests/README.md b/neuralmagic/tests/README.md new file mode 100644 index 0000000000000..acd461e2883c9 --- /dev/null +++ b/neuralmagic/tests/README.md @@ -0,0 +1,29 @@ +**neuralmagic/tests** + +This directory contains a set of `*.txt` files that are used by the build and +test system to skip certain tests for certain build targets in order to +optimize the amount of test execution time required for different scenarios. + +Overall, there are four test 'cycles'/triggers in the build system: + +* remote-push — occurs on any branch & commit being pushed to the repo +* release — currently only triggered manually +* nightly/weekly — basically the same job, runs on a schedule (weekly runs on Sun, nightly runs other days) + +There is a list of test cases associated with each trigger, and a broadly encompassing one: +* skip-almost-all.txt (this is used for rapid GHA dev work to run fast) +* skip-for-remote-push.txt +* skip-for-release.txt +* skip-for-nightly.txt +* skip-for-weekly.txt + +Particularly long-running or less critical tests should not be run during +a remote push, but should probably be run against nightly/weekly builds +and a final release build. In such a scenario, to get your test to run for the +release and nightly/weekly triggers, and skip it for other triggers, add your +test (file) to the following skip lists: +* skip-almost-all.txt +* skip-for-remote-push.txt + +This will basically mean your test is only skipped during remote-push, +and will run for all other triggers. diff --git a/requirements-dev.txt b/requirements-dev.txt index 53de125bdc71e..6d67854aa2f94 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -13,6 +13,7 @@ types-requests==2.31.0.2 types-setuptools # testing +autoawq datasets pytest tensorizer==2.9.0 diff --git a/tests/basic_correctness/test_basic_server_correctness.py b/tests/basic_correctness/test_basic_server_correctness.py index 2d05fe045e6a2..7405c11b60b72 100644 --- a/tests/basic_correctness/test_basic_server_correctness.py +++ b/tests/basic_correctness/test_basic_server_correctness.py @@ -47,80 +47,21 @@ async def my_chat( top_logprobs=num_logprobs) -@pytest.mark.parametrize( - "model, max_model_len, sparsity, gptq_config", - [ - ("mistralai/Mistral-7B-Instruct-v0.2", 4096, None, None), - # pytest.param("mistralai/Mixtral-8x7B-Instruct-v0.1", 4096, None, None, - # # marks=pytest.mark.skip( - # # "skipped because the HFRunner " - # # "will need the 'optimum' package") - # ), - # pytest.param("neuralmagic/zephyr-7b-beta-marlin", 4096, None, None, - # # marks=pytest.mark.skip( - # # "skipped because the HFRunner " - # # "will need the 'optimum' package") - # ), - ("neuralmagic/OpenHermes-2.5-Mistral-7B-pruned50", 4096, - "sparse_w16a16", None), - ("NousResearch/Llama-2-7b-chat-hf", 4096, None, None), - # pytest.param( - # "neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", - # None, - # None, - # None, - # # marks=pytest.mark.skip( - # # "skipped because the HFRunner will need the " - # # "'optimum' package") - # ), - ("neuralmagic/Llama-2-7b-pruned70-retrained-ultrachat", 4096, - "sparse_w16a16", None), - ("HuggingFaceH4/zephyr-7b-gemma-v0.1", 4096, None, None), - ("Qwen/Qwen1.5-7B-Chat", 4096, None, None), - pytest.param( - "microsoft/phi-2", - 2048, - None, - None, - marks=pytest.mark.skip( - "https://app.asana.com/0/1206976017967941/1207409474409275")), - pytest.param( - "neuralmagic/phi-2-super-marlin", - 2048, - None, - None, - marks=pytest.mark.skip( - "https://app.asana.com/0/1206976017967941/1207360919122996")), - pytest.param( - "neuralmagic/phi-2-pruned50", - 2048, - "sparse_w16a16", - None, - marks=pytest.mark.skip( - "https://app.asana.com/0/1206976017967941/1207360919122996")), - # pytest.param( - # "Qwen/Qwen1.5-MoE-A2.7B-Chat", - # 4096, - # None, - # None, - # marks=pytest.mark.skip( - # "CUDA out of memory. Tried to allocate 20.00 MiB. GPU") - # ), - pytest.param("casperhansen/gemma-7b-it-awq", 4096, None, - "gptq_marlin"), - # pytest.param( - # "TheBloke/Llama-2-7B-Chat-GPTQ", - # 4096, - # None, - # None, - # # marks=pytest.mark.skip( - # # "skipped because the HFRunner will need the " - # # "'optimum' package") - # ), - ]) +@pytest.mark.parametrize("model, max_model_len, sparsity, gptq_config", [ + ("mistralai/Mistral-7B-Instruct-v0.2", 4096, None, None), + ("neuralmagic/OpenHermes-2.5-Mistral-7B-pruned50", 4096, "sparse_w16a16", + None), + ("NousResearch/Llama-2-7b-chat-hf", 4096, None, None), + ("neuralmagic/Llama-2-7b-pruned70-retrained-ultrachat", 4096, + "sparse_w16a16", None), + ("Qwen/Qwen1.5-7B-Chat", 4096, None, None), + ("casperhansen/gemma-7b-it-awq", 4096, None, "gptq_marlin"), + ("mistralai/Mixtral-8x7B-Instruct-v0.1", 4096, None, None), + ("Qwen/Qwen1.5-MoE-A2.7B-Chat", 4096, None, None), +]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [3]) -@pytest.mark.parametrize("tensor_parallel_size", [None, 2]) +@pytest.mark.parametrize("tensor_parallel_size", [None]) # note: repeating the test for 2 values of tensor_parallel_size # increases the overall execution time by unnecessarily # collecting the HuggingFace runner data twice. @@ -140,7 +81,14 @@ def test_models_on_server( This test compares the output of the vllm OpenAI server against that of a HuggingFace transformer. We expect them to be fairly close. "Close" is measured by checking that the top 3 logprobs for each token includes - the token of the other inference tool. + the token of the other inference tool. The first time that there is no + exact match, as long as there is a match to one of the top `num_logprobs` + logprobs, the test will not proceed further, but will pass. + + Parameters to the test identify a model to test, and key arguments + required for that model (see the `max_model_len`, `sparsity` and + `gptq_config` params below). The additional parametrizations expand test + coverage across the functional space of the server. :param hf_runner_nm: fixture for the HfRunnerNM :param client: fixture with an openai.AsyncOpenAI client From ecd752e99102c9f6378a1232b9706f04c40087e3 Mon Sep 17 00:00:00 2001 From: derekk-nm Date: Wed, 29 May 2024 10:39:27 +0000 Subject: [PATCH 5/5] skip on push adding tests/basic_correctness/test_basic_server_correctness.py to skip-for-remote-push-tmp.txt --- neuralmagic/tests/skip-for-remote-push-tmp.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/neuralmagic/tests/skip-for-remote-push-tmp.txt b/neuralmagic/tests/skip-for-remote-push-tmp.txt index cc601630521e3..f0e5a8f6ab575 100644 --- a/neuralmagic/tests/skip-for-remote-push-tmp.txt +++ b/neuralmagic/tests/skip-for-remote-push-tmp.txt @@ -99,6 +99,7 @@ tests/engine/output_processor/test_multi_step.py tests/engine/test_computed_prefix_blocks.py tests/basic_correctness/test_chunked_prefill.py tests/basic_correctness/test_preemption.py +tests/basic_correctness/test_basic_server_correctness.py tests/test_cache_block_hashing.py tests/test_logger.py tests/test_regression.py