diff --git a/neuralmagic/tests/README.md b/neuralmagic/tests/README.md
new file mode 100644
index 0000000000000..acd461e2883c9
--- /dev/null
+++ b/neuralmagic/tests/README.md
@@ -0,0 +1,29 @@
+**neuralmagic/tests**
+
+This directory contains a set of `*.txt` files that are used by the build and
+test system to skip certain tests for certain build targets in order to
+optimize the amount of test execution time required for different scenarios.
+
+Overall, there are four test 'cycles'/triggers in the build system (nightly and weekly share a job):
+
+* remote-push — occurs on any branch & commit being pushed to the repo
+* release — currently only triggered manually
+* nightly/weekly — basically the same job, runs on a schedule (weekly runs on Sun, nightly runs other days)
+
+There is a skip list associated with each trigger, plus one broadly encompassing list:
+* skip-almost-all.txt (used for rapid GHA dev work, to keep runs fast)
+* skip-for-remote-push.txt
+* skip-for-release.txt
+* skip-for-nightly.txt
+* skip-for-weekly.txt
+
+Particularly long-running or less critical tests should not be run during
+a remote push, but should probably be run against nightly/weekly builds
+and a final release build. In that scenario, to have your test run for the
+release and nightly/weekly triggers and be skipped otherwise, add your
+test (file) to the following skip lists:
+* skip-almost-all.txt
+* skip-for-remote-push.txt
+
+This means your test is skipped only during remote-push
+and runs for all other triggers.
diff --git a/neuralmagic/tests/skip-almost-all.txt b/neuralmagic/tests/skip-almost-all.txt
index 99a541c7e1628..e2e7cf036f569 100644
--- a/neuralmagic/tests/skip-almost-all.txt
+++ b/neuralmagic/tests/skip-almost-all.txt
@@ -84,4 +84,5 @@ tests/engine/test_detokenization.py
 tests/engine/test_computed_prefix_blocks.py
 tests/basic_correctness/test_chunked_prefill.py
 tests/basic_correctness/test_basic_correctness.py
+tests/basic_correctness/test_basic_server_correctness.py
 tests/test_cache_block_hashing.py
diff --git a/neuralmagic/tests/skip-for-remote-push-tmp.txt b/neuralmagic/tests/skip-for-remote-push-tmp.txt
index cc601630521e3..f0e5a8f6ab575 100644
--- a/neuralmagic/tests/skip-for-remote-push-tmp.txt
+++ b/neuralmagic/tests/skip-for-remote-push-tmp.txt
@@ -99,6 +99,7 @@ tests/engine/output_processor/test_multi_step.py
 tests/engine/test_computed_prefix_blocks.py
 tests/basic_correctness/test_chunked_prefill.py
 tests/basic_correctness/test_preemption.py
+tests/basic_correctness/test_basic_server_correctness.py
 tests/test_cache_block_hashing.py
 tests/test_logger.py
 tests/test_regression.py
diff --git a/neuralmagic/tests/skip-for-remote-push.txt b/neuralmagic/tests/skip-for-remote-push.txt
index 9907f8b51c66e..c52db98f46a1d 100644
--- a/neuralmagic/tests/skip-for-remote-push.txt
+++ b/neuralmagic/tests/skip-for-remote-push.txt
@@ -49,4 +49,5 @@ tests/lora/test_lora.py
 tests/worker/test_model_runner.py
 tests/engine/test_detokenize.py
 tests/engine/test_computed_prefix_blocks.py
+tests/basic_correctness/test_basic_server_correctness.py
 tests/accuracy/test_lm_eval_correctness.py
diff --git a/requirements-dev.txt b/requirements-dev.txt
index e0418fc7c250f..6d67854aa2f94 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -13,6 +13,8 @@ types-requests==2.31.0.2
 types-setuptools
 
 # testing
+autoawq
+datasets
 pytest
 tensorizer==2.9.0
 pytest-forked
diff --git a/tests/basic_correctness/__init__.py b/tests/basic_correctness/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
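The build-system code that consumes these skip lists is not part of this diff, so the following is only a hypothetical sketch of how a CI step might expand one of the `*.txt` files into pytest arguments; the skip-list path, the use of `--ignore`, and the entry point are assumptions made for illustration.

```python
# Hypothetical sketch (not part of this change): translate a skip list into
# pytest --ignore arguments. The path and flag choice are assumptions; the
# real build scripts may consume the files differently.
import subprocess
import sys
from pathlib import Path


def run_with_skip_list(skip_list: Path) -> int:
    # Each non-empty line in the skip list names a test file that should
    # not be collected for this trigger.
    ignores = [
        f"--ignore={line.strip()}"
        for line in skip_list.read_text().splitlines()
        if line.strip()
    ]
    return subprocess.call(["pytest", "tests", *ignores])


if __name__ == "__main__":
    sys.exit(
        run_with_skip_list(Path("neuralmagic/tests/skip-for-remote-push.txt")))
```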
diff --git a/tests/basic_correctness/test_basic_server_correctness.py b/tests/basic_correctness/test_basic_server_correctness.py
new file mode 100644
index 0000000000000..7405c11b60b72
--- /dev/null
+++ b/tests/basic_correctness/test_basic_server_correctness.py
@@ -0,0 +1,187 @@
+import asyncio
+from os import getenv
+from typing import Dict, List, Type
+
+import openai
+import pytest
+import torch
+from datasets import load_dataset
+from openai import AsyncOpenAI
+from transformers import AutoTokenizer
+
+from tests.conftest import HfRunnerNM
+from tests.models.compare_utils import check_logprobs_close
+from tests.utils.logging import make_logger
+from tests.utils.server import ServerContext
+from vllm.model_executor.layers.quantization import get_quantization_config
+
+
+@pytest.fixture(scope="session")
+def client():
+    client = openai.AsyncOpenAI(
+        base_url="http://localhost:8000/v1",
+        api_key="token-abc123",
+    )
+    yield client
+
+
+@pytest.fixture
+def hf_runner_nm() -> Type[HfRunnerNM]:
+    return HfRunnerNM
+
+
+async def my_chat(
+    client,
+    model: str,
+    messages: List[Dict],
+    max_tokens: int,
+    temperature: float,
+    num_logprobs: int,
+):
+    """Submit a single chat request and collect the result."""
+    return await client.chat.completions.create(model=model,
+                                                messages=messages,
+                                                max_tokens=max_tokens,
+                                                temperature=temperature,
+                                                logprobs=True,
+                                                top_logprobs=num_logprobs)
+
+
+@pytest.mark.parametrize("model, max_model_len, sparsity, gptq_config", [
+    ("mistralai/Mistral-7B-Instruct-v0.2", 4096, None, None),
+    ("neuralmagic/OpenHermes-2.5-Mistral-7B-pruned50", 4096, "sparse_w16a16",
+     None),
+    ("NousResearch/Llama-2-7b-chat-hf", 4096, None, None),
+    ("neuralmagic/Llama-2-7b-pruned70-retrained-ultrachat", 4096,
+     "sparse_w16a16", None),
+    ("Qwen/Qwen1.5-7B-Chat", 4096, None, None),
+    ("casperhansen/gemma-7b-it-awq", 4096, None, "gptq_marlin"),
+    ("mistralai/Mixtral-8x7B-Instruct-v0.1", 4096, None, None),
+    ("Qwen/Qwen1.5-MoE-A2.7B-Chat", 4096, None, None),
+])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("num_logprobs", [3])
+@pytest.mark.parametrize("tensor_parallel_size", [None])
+# note: parametrizing tensor_parallel_size over two values would
+# increase the overall execution time by unnecessarily collecting
+# the HuggingFace runner data twice.
+# Consider refactoring to eliminate that repeat.
+def test_models_on_server(
+    hf_runner_nm: HfRunnerNM,
+    client: AsyncOpenAI,
+    model: str,
+    max_model_len: int,
+    sparsity: str,
+    gptq_config: str,
+    tensor_parallel_size: int,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+    """
+    This test compares the output of the vllm OpenAI server against that of
+    a HuggingFace transformer. We expect them to be fairly close. "Close"
+    means that, at every compared position, the token chosen by one tool
+    appears in the other tool's top `num_logprobs` logprobs. At the first
+    position where the two tokens differ, the comparison stops: the test
+    still passes there as long as each token is in the other's top logprobs.
+
+    Parameters to the test identify a model to test, and key arguments
+    required for that model (see the `max_model_len`, `sparsity` and
+    `gptq_config` params below). The additional parametrizations expand test
+    coverage across the functional space of the server.
+
+    :param hf_runner_nm: fixture for the HfRunnerNM
+    :param client: fixture with an openai.AsyncOpenAI client
+    :param model: The Huggingface id for the model to test with
+    :param max_model_len: passed to the vllm Server's --max-model-len option
+    :param sparsity: passed to the vllm Server's --sparsity option
+    :param gptq_config: quantization method id for this model. The default,
+        None, means quantization isn't involved.
+    :param tensor_parallel_size: passed to the vllm Server's
+        --tensor-parallel-size option
+    :param max_tokens: the total number of tokens to consider for closeness
+    :param num_logprobs: the total number of logprobs included when
+        calculating closeness
+    """
+    logger = make_logger("vllm_test")
+    # check that the requested gpu count is available in the test env
+    gpu_count = torch.cuda.device_count()
+    if tensor_parallel_size and gpu_count < tensor_parallel_size:
+        pytest.skip(f"gpu count {gpu_count} is insufficient for "
+                    f"tensor_parallel_size = {tensor_parallel_size}")
+
+    # skip this model if the current device does not have the required
+    # gpu capability.
+    device_capability = torch.cuda.get_device_capability()
+    capability = device_capability[0] * 10 + device_capability[1]
+    if gptq_config and (
+            capability <
+            get_quantization_config(gptq_config).get_min_capability()):
+        pytest.skip("insufficient system GPU device capability "
+                    f"({capability}) for this model")
+
+    hf_token = getenv("HF_TOKEN", None)
+    logger.info("loading chat prompts for testing.")
+    ds = load_dataset("nm-testing/qa-chat-prompts", split="train_sft")
+    num_chat_turns = 3
+    messages_list = [row["messages"][:num_chat_turns] for row in ds]
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    chat_prompts = [
+        tokenizer.apply_chat_template(messages,
+                                      tokenize=False,
+                                      add_generation_prompt=True)
+        for messages in messages_list
+    ]
+
+    logger.info("generating chat responses from HuggingFace runner.")
+    hf_model = hf_runner_nm(model, access_token=hf_token)
+    hf_outputs = hf_model.generate_greedy_logprobs_nm_use_tokens(
+        chat_prompts, max_tokens, num_logprobs, ignore_special_tokens=True)
+
+    del hf_model
+
+    logger.info("generating chat responses from vllm server.")
+    api_server_args = {
+        "--model": model,
+        "--max-model-len": max_model_len,
+        "--disable-log-requests": None,
+    }
+    if sparsity:
+        api_server_args["--sparsity"] = sparsity
+    if tensor_parallel_size:
+        api_server_args["--tensor-parallel-size"] = tensor_parallel_size
+
+    # some devices will require a different `dtype`
+    if device_capability[0] < 8:
+        api_server_args["--dtype"] = "half"
+
+    asyncio_event_loop = asyncio.get_event_loop()
+    temperature = 0.0
+    with ServerContext(api_server_args, logger=logger) as _:
+        # submit an asynchronous request to the server for each prompt
+        chats = [
+            my_chat(client, model, messages, max_tokens, temperature,
+                    num_logprobs)
+            for messages in messages_list
+        ]
+        # wait for all the requests to return, and gather their results
+        # in one place
+        results = asyncio_event_loop.run_until_complete(asyncio.gather(*chats))
+
+    logger.info("preparing results from vllm server requests to include "
+                "tokens and logprobs.")
+    vllm_outputs = list()
+    for task_result in results:
+        for req_output in task_result.choices:
+            output_str = req_output.message.content
+            output_tokens = req_output.logprobs.model_extra["tokens"]
+            output_logprobs = req_output.logprobs.model_extra["top_logprobs"]
+            vllm_outputs.append((output_tokens, output_str, output_logprobs))
+
+    logger.info("comparing HuggingFace and vllm Server chat responses")
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf_model",
+        name_1="vllm_model",
+    )
diff --git a/tests/conftest.py b/tests/conftest.py
index 3b557df55c8a0..6c264b09ecb75 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,7 +2,7 @@
 import gc
 import logging
 import os
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import pytest
 import torch
@@ -143,6 +143,7 @@ def __init__(
         model_name: str,
         tokenizer_name: Optional[str] = None,
         dtype: str = "half",
+        access_token: Optional[str] = None,
     ) -> None:
         assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
         torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
@@ -152,6 +153,7 @@ def __init__(
                 model_name,
                 torch_dtype=torch_dtype,
                 trust_remote_code=True,
+                token=access_token,
             ).cuda()
             self.processor = None
         else:
@@ -159,6 +161,7 @@ def __init__(
                 model_name,
                 torch_dtype=torch_dtype,
                 trust_remote_code=True,
+                token=access_token,
             ).cuda()
             self.processor = AutoProcessor.from_pretrained(
                 model_name,
@@ -287,12 +290,31 @@ def hf_runner():
 # UPSTREAM SYNC: needed for nm-automation
 class HfRunnerNM(HfRunner):
 
+    def _logprobs_from_generated_hidden_states(self, output):
+        """
+        generates a list of logprobs from the output of self.model.generate()
+        """
+        seq_logprobs = []
+        for _, hidden_states in enumerate(output.hidden_states):
+            last_hidden_states = hidden_states[-1][0]
+            logits = torch.matmul(
+                last_hidden_states,
+                self.model.get_output_embeddings().weight.t(),
+            )
+            if self.model.get_output_embeddings().bias is not None:
+                logits += self.model.get_output_embeddings().bias.unsqueeze(0)
+            logprobs = torch.nn.functional.log_softmax(logits,
+                                                       dim=-1,
+                                                       dtype=torch.float32)
+            seq_logprobs.append(logprobs)
+        return seq_logprobs
+
     def generate_greedy_logprobs_nm(
         self,
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[int], str, List[Dict]]]:
         all_logprobs = []
         all_output_ids = []
         all_output_strs = []
@@ -308,20 +330,7 @@ def generate_greedy_logprobs_nm(
                 return_dict_in_generate=True,
             )
 
-            seq_logprobs = []
-            for _, hidden_states in enumerate(output.hidden_states):
-                last_hidden_states = hidden_states[-1][0]
-                logits = torch.matmul(
-                    last_hidden_states,
-                    self.model.get_output_embeddings().weight.t(),
-                )
-                if self.model.get_output_embeddings().bias is not None:
-                    logits += self.model.get_output_embeddings(
-                    ).bias.unsqueeze(0)
-                logprobs = torch.nn.functional.log_softmax(logits,
-                                                           dim=-1,
-                                                           dtype=torch.float32)
-                seq_logprobs.append(logprobs)
+            seq_logprobs = self._logprobs_from_generated_hidden_states(output)
 
             # convert to dict
             seq_logprobs_lst = []
@@ -348,6 +357,107 @@
         return [(output_ids, output_str, output_logprobs)
                 for output_ids, output_str, output_logprobs in outputs]
 
+    SPECIAL_TOKENS = ["<|im_start|>", "<|im_end|>", "�", ""]
+
+    def _decode_token_by_position_index(
+            self,
+            token_index: int,
+            token_ids: List[int],
+            clean_up_tokenization_spaces: bool = True,
+            ignore_special_tokens: bool = False) -> str:
+        """
+        helper function to calculate the token string at the specified index
+        by decoding it together with the tokens that precede it.
+
+        :param token_index: position in the list of token ids where you would
+            like the token string
+        :param token_ids: the list of token ids
+        :param clean_up_tokenization_spaces: option to pass to
+            `tokenizer.decode()`
+        :param ignore_special_tokens: replaces hard-coded special tokens with
+            an empty string
+        """
+        lookback = 4
+        prior_str = self.tokenizer.decode(
+            token_ids[token_index - lookback:token_index - 1],
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces)
+        current_str = self.tokenizer.decode(
+            token_ids[token_index - lookback:token_index],
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces)
+        token = current_str[-(len(current_str) - len(prior_str)):]
+        if ignore_special_tokens and token in self.SPECIAL_TOKENS:
+            token = ""
+        return token
+
+    def generate_greedy_logprobs_nm_use_tokens(
+        self,
+        prompts: List[str],
+        max_tokens: int,
+        topk_logprobs_count: int,
+        ignore_special_tokens: bool = False
+    ) -> List[Tuple[List[int], str, List[Dict]]]:
+        all_logprobs = []
+        all_output_tokens = []
+        all_output_strs = []
+
+        for prompt in prompts:
+            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+            output = self.model.generate(
+                input_ids.cuda(),
+                use_cache=True,
+                do_sample=False,
+                max_new_tokens=max_tokens,
+                output_hidden_states=True,
+                return_dict_in_generate=True,
+            )
+
+            seq_logprobs = self._logprobs_from_generated_hidden_states(output)
+
+            # convert sequence of logprobs to a dict keyed on the selected token
+            seq_ids = output.sequences[0]
+            input_len = input_ids.shape[1]
+            seq_logprobs_lst = list()
+            output_tokens = list()
+
+            for tok_idx, tok_logprobs in enumerate(seq_logprobs):
+                # drop prompt logprobs
+                if tok_idx == 0:
+                    tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
+                topk = tok_logprobs.topk(topk_logprobs_count)
+
+                tok_logprobs_dct = {}
+                # add 1 to the index here to be in sync with the skipped prompt
+                # in the logprobs
+                indexed_seq = seq_ids[:input_len + tok_idx + 1].tolist()
+                token_str = self._decode_token_by_position_index(
+                    input_len + tok_idx + 1,
+                    indexed_seq,
+                    clean_up_tokenization_spaces=False,
+                    ignore_special_tokens=ignore_special_tokens)
+                output_tokens.append(token_str)
+                for alt_token_id, logprob in zip(topk.indices[0],
+                                                 topk.values[0]):
+                    # replace the last token id in the sequence with
+                    # alt_token_id, then run decode again
+                    indexed_seq[-1] = alt_token_id
+                    logprob_str = self._decode_token_by_position_index(
+                        input_len + tok_idx + 1,
+                        indexed_seq,
+                        clean_up_tokenization_spaces=False,
+                        ignore_special_tokens=ignore_special_tokens)
+                    tok_logprobs_dct[logprob_str] = logprob.item()
+
+                seq_logprobs_lst.append(tok_logprobs_dct)
+
+            all_logprobs.append(seq_logprobs_lst)
+
+            all_output_tokens.append(output_tokens)
+            all_output_strs.append("".join(output_tokens))
+
+        outputs = zip(all_output_tokens, all_output_strs, all_logprobs)
+        return [(output_ids, output_str, output_logprobs)
+                for output_ids, output_str, output_logprobs in outputs]
+
     def __del__(self):
         del self.model
         cleanup()
diff --git a/tests/models/compare_utils.py b/tests/models/compare_utils.py
index 44319b6ca45ff..051cbf1547b21 100644
--- a/tests/models/compare_utils.py
+++ b/tests/models/compare_utils.py
@@ -1,4 +1,4 @@
-"""Compare the logprobs of two sequences generated by different models,
+"""Compare the logprobs of two sequences generated by different models,
 which should be similar but not necessarily equal.
""" @@ -19,13 +19,19 @@ def check_logprobs_close(outputs_0_lst, outputs_1_lst, name_0, name_1): if output_id_0 != output_id_1: # Each predicted token must be in top N logprobs of the other assert output_id_0 in logprobs_1[idx], ( - f"Test{prompt_idx}:" + f"{name_0} token '{output_id_0}' not in " + f"[{list(logprobs_1[idx].keys())}]" + f"\nprompt index {prompt_idx}, token index {idx}:" f"\n{name_0}:\t{output_str_0!r}" f"\n{name_1}:\t{output_str_1!r}") assert output_id_1 in logprobs_0[idx], ( - f"Test{prompt_idx}:" + f"{name_1} token '{output_id_1}' not in " + f"[{list(logprobs_0[idx].keys())}]" + f"\nprompt index {prompt_idx}, token index {idx}:" f"\n{name_0}:\t{output_str_0!r}" f"\n{name_1}:\t{output_str_1!r}") # Break out since sequences will now diverge. + # as long as we got this far with the output tokens being the + # same, or close, the responses are close enough break