From fb7a2245bff4ab20a848411cc8034fc76c864bf6 Mon Sep 17 00:00:00 2001
From: Kartik Khandelwal
Date: Sat, 18 Nov 2023 12:30:03 -0500
Subject: [PATCH] add python, rest api test (#1278)

* add python, rest api test

* remove mistral, fix pylint

* fix pylint requests import error
---
 ci/task/pylint.sh               |  2 +-
 tests/python/api/test_python.py | 45 +++++++++++++++++++++
 tests/python/api/test_rest.py   | 71 +++++++++++++++++++++++++++++++++
 3 files changed, 117 insertions(+), 1 deletion(-)
 create mode 100644 tests/python/api/test_python.py
 create mode 100644 tests/python/api/test_rest.py

diff --git a/ci/task/pylint.sh b/ci/task/pylint.sh
index 7d2a0d326b..fb07ba6087 100755
--- a/ci/task/pylint.sh
+++ b/ci/task/pylint.sh
@@ -9,7 +9,7 @@ export PYTHONPATH="./python:$PYTHONPATH"
 set -x
 
 # TVM Unity is a dependency to this testing
-pip install --quiet --pre -U -f https://mlc.ai/wheels mlc-ai-nightly
+pip install --quiet --pre -U -f https://mlc.ai/wheels mlc-ai-nightly requests
 
 pylint --jobs $NUM_THREADS ./python/
 pylint --jobs $NUM_THREADS --recursive=y ./tests/python/
diff --git a/tests/python/api/test_python.py b/tests/python/api/test_python.py
new file mode 100644
index 0000000000..ceba066a13
--- /dev/null
+++ b/tests/python/api/test_python.py
@@ -0,0 +1,45 @@
+# pylint: disable=missing-docstring
+import pytest
+
+from mlc_chat import ChatModule, GenerationConfig
+from mlc_chat.callback import StreamToStdout
+
+MODELS = ["Llama-2-7b-chat-hf-q4f16_1"]
+
+
+@pytest.mark.parametrize("model", MODELS)
+def test_chat_module_creation_and_generate(model: str):
+    chat_module = ChatModule(model=model)
+    _ = chat_module.generate(
+        prompt="How to make a cake?",
+    )
+    print(f"Statistics: {chat_module.stats()}\n")
+
+
+@pytest.mark.parametrize("model", MODELS)
+def test_chat_module_creation_and_generate_with_stream(model: str):
+    chat_module = ChatModule(model=model)
+    _ = chat_module.generate(
+        prompt="How to make a cake?",
+        progress_callback=StreamToStdout(callback_interval=2),
+    )
+    print(f"Statistics: {chat_module.stats()}\n")
+
+
+@pytest.mark.parametrize(
+    "generation_config",
+    [
+        GenerationConfig(temperature=0.7, presence_penalty=0.1, frequency_penalty=0.5, top_p=0.9),
+        GenerationConfig(stop=["cake", "make"], n=3),
+        GenerationConfig(max_gen_len=40, repetition_penalty=1.2),
+    ],
+)
+@pytest.mark.parametrize("model", MODELS)
+def test_chat_module_generation_config(generation_config: GenerationConfig, model: str):
+    chat_module = ChatModule(model=model)
+    output = chat_module.generate(
+        prompt="How to make a cake?",
+        generation_config=generation_config,
+    )
+    print(output)
+    print(f"Statistics: {chat_module.stats()}\n")
diff --git a/tests/python/api/test_rest.py b/tests/python/api/test_rest.py
new file mode 100644
index 0000000000..de6e2bb793
--- /dev/null
+++ b/tests/python/api/test_rest.py
@@ -0,0 +1,71 @@
+# pylint: disable=missing-docstring
+import json
+import os
+import signal
+import subprocess
+import time
+
+import pytest
+import requests
+
+MODELS = ["Llama-2-7b-chat-hf-q4f16_1"]
+
+
+@pytest.fixture
+def run_rest_server(model):
+    cmd = f"python -m mlc_chat.rest --model {model}"
+    print(cmd)
+    os.environ["PYTHONPATH"] = "./python"
+    with subprocess.Popen(cmd.split()) as server_proc:
+        # wait for server to start
+        while True:
+            try:
+                _ = requests.get("http://localhost:8000/stats", timeout=5)
+                break
+            except requests.exceptions.ConnectionError:
+                time.sleep(1)
+        yield
+        server_proc.send_signal(signal.SIGINT)
+        server_proc.wait()
+
+
+@pytest.mark.usefixtures("run_rest_server")
+@pytest.mark.parametrize("stream", [True, False]) +@pytest.mark.parametrize("model", MODELS) +def test_rest_api(model, stream): + payload = { + "model": model, + "messages": [ + { + "role": "user", + "content": "Hello, I am Bob", + }, + { + "role": "assistant", + "content": "Hello, I am a chatbot.", + }, + { + "role": "user", + "content": "What is my name?", + }, + ], + "stream": stream, + "frequency_penalty": 0.0, + "presence_penalty": 0.0, + "temperature": 1.0, + "top_p": 0.95, + } + if stream: + with requests.post( + "http://127.0.0.1:8000/v1/chat/completions", json=payload, stream=True, timeout=120 + ) as model_response: + print("With streaming:") + for chunk in model_response: + content = json.loads(chunk[6:-2])["choices"][0]["delta"].get("content", "") + print(f"{content}", end="", flush=True) + print("\n") + else: + model_response = requests.post( + "http://127.0.0.1:8000/v1/chat/completions", json=payload, timeout=120 + ) + print(f"\n{model_response.json()['choices'][0]['message']['content']}\n")