Skip to content

Commit

Permalink
add python, rest api test (#1278)
Browse files Browse the repository at this point in the history
* add python, rest api test

* remove mistral, fix pylint

* fix pylint requests import error
  • Loading branch information
Kartik14 authored Nov 18, 2023
1 parent 31910dd commit fb7a224
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 1 deletion.
2 changes: 1 addition & 1 deletion ci/task/pylint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export PYTHONPATH="./python:$PYTHONPATH"
set -x

# TVM Unity is a dependency to this testing
pip install --quiet --pre -U -f https://mlc.ai/wheels mlc-ai-nightly
pip install --quiet --pre -U -f https://mlc.ai/wheels mlc-ai-nightly requests

pylint --jobs $NUM_THREADS ./python/
pylint --jobs $NUM_THREADS --recursive=y ./tests/python/
45 changes: 45 additions & 0 deletions tests/python/api/test_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# pylint: disable=missing-docstring
import pytest

from mlc_chat import ChatModule, GenerationConfig
from mlc_chat.callback import StreamToStdout

MODELS = ["Llama-2-7b-chat-hf-q4f16_1"]


@pytest.mark.parametrize("model", MODELS)
def test_chat_module_creation_and_generate(model: str):
    """Smoke test: construct a ChatModule and run one non-streaming generate call."""
    cm = ChatModule(model=model)
    _ = cm.generate(prompt="How to make a cake?")
    # Print runtime stats so CI logs show decode/prefill throughput.
    print(f"Statistics: {cm.stats()}\n")


@pytest.mark.parametrize("model", MODELS)
def test_chat_module_creation_and_generate_with_stream(model: str):
    """Smoke test: generate with a streaming callback that echoes tokens to stdout."""
    cm = ChatModule(model=model)
    callback = StreamToStdout(callback_interval=2)
    _ = cm.generate(prompt="How to make a cake?", progress_callback=callback)
    # Print runtime stats so CI logs show decode/prefill throughput.
    print(f"Statistics: {cm.stats()}\n")


@pytest.mark.parametrize(
    "generation_config",
    [
        GenerationConfig(temperature=0.7, presence_penalty=0.1, frequency_penalty=0.5, top_p=0.9),
        GenerationConfig(stop=["cake", "make"], n=3),
        GenerationConfig(max_gen_len=40, repetition_penalty=1.2),
    ],
)
@pytest.mark.parametrize("model", MODELS)
def test_chat_module_generation_config(generation_config: GenerationConfig, model: str):
    """Check that generate accepts a variety of GenerationConfig combinations."""
    cm = ChatModule(model=model)
    result = cm.generate(prompt="How to make a cake?", generation_config=generation_config)
    # Echo the output and stats for manual inspection in CI logs.
    print(result)
    print(f"Statistics: {cm.stats()}\n")
71 changes: 71 additions & 0 deletions tests/python/api/test_rest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# pylint: disable=missing-docstring
import json
import os
import signal
import subprocess
import time

import pytest
import requests

MODELS = ["Llama-2-7b-chat-hf-q4f16_1"]


@pytest.fixture
def run_rest_server(model):
    """Launch the REST server for *model* in a subprocess; tear it down afterwards.

    Yields once ``GET /stats`` responds, guaranteeing the server is reachable
    before the dependent test runs.

    Improvements over a naive wait loop:
    - fails fast with a RuntimeError if the server process exits during startup,
      instead of polling a dead port forever;
    - bounds the startup wait (5 minutes) so a broken server fails CI rather
      than hanging it;
    - shuts the server down in a ``finally`` block so it is not leaked when the
      test body raises.
    """
    cmd = f"python -m mlc_chat.rest --model {model}"
    print(cmd)
    os.environ["PYTHONPATH"] = "./python"
    with subprocess.Popen(cmd.split()) as server_proc:
        # Wait for the server to start, but bound the wait so a server that
        # never comes up fails the test instead of stalling CI indefinitely.
        deadline = time.monotonic() + 300
        while True:
            if server_proc.poll() is not None:
                raise RuntimeError(
                    f"REST server exited during startup with code {server_proc.returncode}"
                )
            try:
                _ = requests.get("http://localhost:8000/stats", timeout=5)
                break
            except requests.exceptions.ConnectionError:
                if time.monotonic() > deadline:
                    server_proc.send_signal(signal.SIGINT)
                    server_proc.wait()
                    raise RuntimeError("REST server did not start within 300 seconds")
                time.sleep(1)
        try:
            yield
        finally:
            # Always stop the server, even if the test body raised.
            server_proc.send_signal(signal.SIGINT)
            server_proc.wait()


@pytest.mark.usefixtures("run_rest_server")
@pytest.mark.parametrize("stream", [True, False])
@pytest.mark.parametrize("model", MODELS)
def test_rest_api(model, stream):
    """Exercise the OpenAI-style chat-completions endpoint, streamed and unstreamed."""
    messages = [
        {"role": "user", "content": "Hello, I am Bob"},
        {"role": "assistant", "content": "Hello, I am a chatbot."},
        {"role": "user", "content": "What is my name?"},
    ]
    payload = {
        "model": model,
        "messages": messages,
        "stream": stream,
        "frequency_penalty": 0.0,
        "presence_penalty": 0.0,
        "temperature": 1.0,
        "top_p": 0.95,
    }
    url = "http://127.0.0.1:8000/v1/chat/completions"
    if not stream:
        model_response = requests.post(url, json=payload, timeout=120)
        print(f"\n{model_response.json()['choices'][0]['message']['content']}\n")
        return
    with requests.post(url, json=payload, stream=True, timeout=120) as model_response:
        print("With streaming:")
        for chunk in model_response:
            # NOTE(review): assumes each chunk arrives as a full SSE event of the
            # form b"data: {...}\n\n" — the slice strips the 6-byte prefix and
            # 2-byte trailer. Verify against the server's framing.
            content = json.loads(chunk[6:-2])["choices"][0]["delta"].get("content", "")
            print(f"{content}", end="", flush=True)
        print("\n")

0 comments on commit fb7a224

Please sign in to comment.