[Benchmark] guided decoding #10046

Closed · wants to merge 5 commits
326 changes: 326 additions & 0 deletions benchmarks/benchmark_guided.py
@@ -0,0 +1,326 @@
"""Benchmark guided decoding throughput."""
import argparse
import json
import random
import time
from typing import List

import uvloop
from transformers import AutoTokenizer

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
from vllm.sampling_params import GuidedDecodingParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators

SCHEMA = {
    "$schema": "https://json-schema.org/draft/2020-12/schema",
    "title": "User Profile",
    "type": "object",
    "properties": {
        "userId": {
            "type": "string",
            "description": "Unique identifier for the user."
        },
        "personalInfo": {
            "type": "object",
            "properties": {
                "firstName": {
                    "type": "string",
                    "description": "The user's first name."
                },
                "lastName": {
                    "type": "string",
                    "description": "The user's last name."
                },
                "age": {
                    "type": "integer",
                    "minimum": 0,
                    "description": "The user's age."
                },
                "phoneNumbers": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "type": {
                                "type": "string",
                                "enum": ["home", "work", "mobile"],
                                "description": "Type of phone number."
                            },
                            "number": {
                                "type": "string",
                                "pattern": "^\\+?[1-9]\\d{1,14}$",
                                "description": "Phone number in E.164 format."
                            }
                        },
                        "required": ["type", "number"]
                    },
                    "description":
                    "List of phone numbers associated with the user."
                }
            },
            "required": ["firstName", "lastName"]
        },
        "address": {
            "type": "object",
            "properties": {
                "street": {
                    "type": "string",
                    "description": "Street address."
                },
                "city": {
                    "type": "string",
                    "description": "City name."
                },
                "state": {
                    "type": "string",
                    "description": "State or province."
                },
                "postalCode": {
                    "type": "string",
                    "pattern": "^\\d{5}(-\\d{4})?$",
                    "description": "Postal code."
                },
                "country": {
                    "type": "string",
                    "description": "Country name."
                }
            },
            "required": ["street", "city", "state", "postalCode", "country"]
        },
        "preferences": {
            "type": "object",
            "properties": {
                "newsletterSubscribed": {
                    "type": "boolean",
                    "description":
                    "Indicates if the user is subscribed to the newsletter."
                },
                "favoriteCategories": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    },
                    "description": "List of user's favorite categories."
                }
            },
            "required": ["newsletterSubscribed"]
        },
        "accountStatus": {
            "type": "string",
            "enum": ["active", "inactive", "suspended"],
            "description": "Current status of the user's account."
        },
        "registrationDate": {
            "type": "string",
            "format": "date-time",
            "description": "ISO 8601 formatted date-time of user registration."
        }
    },
    "required": [
        "userId", "personalInfo", "address", "accountStatus",
        "registrationDate"
    ]
}


def run_vllm(
    requests: List[tuple[str, int, int]],
    engine_args: EngineArgs,
    n: int,
    guided_decoding: bool = False,
    warmup: bool = False,
@xuechendi (Contributor) commented on Nov 21, 2024:
How about adding a `result_file_name` via a `--save-results` argument, so the generated outputs can be saved for a result accuracy check?

result_file_name = f"{'guided' if args.guided_decoding else 'no_guided'}_{args.num_prompts}_out{args.output_len}_async{args.async_engine}_warmup{args.warmup}.json"

outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
ret = []
for output in outputs:
    generated_text = output.outputs[0].text
    ret.append(generated_text)

# Save the generated texts to a JSON file.
if result_file_name:
    with open(result_file_name, 'w') as f:
        json.dump(ret, f, indent=4)

) -> float:
    from vllm import LLM, SamplingParams
    llm = LLM(**vars(engine_args))

    # Add the requests to the engine.
    prompts: List[str] = []
    sampling_params: List[SamplingParams] = []
    for prompt, _, output_len in requests:
        prompts.append(prompt)
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=output_len,
                guided_decoding=GuidedDecodingParams(
                    json=SCHEMA) if guided_decoding else None,
            ))

    start = time.perf_counter()
    llm.generate(prompts, sampling_params, use_tqdm=True)
    end = time.perf_counter()
    return end - start
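
The reviewer's `--save-results` idea above could be wired in roughly as follows. This is a minimal sketch, not code from this PR: the flag name, the derived `result_file_name`, and where it is threaded through are assumptions based on the comment.

# Hypothetical wiring for the --save-results suggestion (not part of this PR).
parser.add_argument("--save-results",
                    action="store_true",
                    default=False,
                    help="Save generated texts to JSON for accuracy checks.")

# After parse_args(), derive the result file name from the run configuration
# and pass it into run_vllm / run_vllm_async so they can dump the outputs.
result_file_name = (
    f"{'guided' if args.guided_decoding else 'no_guided'}_{args.num_prompts}"
    f"_out{args.output_len}_async{args.async_engine}_warmup{args.warmup}.json"
) if args.save_results else None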


async def run_vllm_async(
    requests: List[tuple[str, int, int]],
    engine_args: AsyncEngineArgs,
    n: int,
    guided_decoding: bool = False,
    warmup: bool = False,
    disable_frontend_multiprocessing: bool = False,
) -> float:
    from vllm import SamplingParams

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:

        # Add the requests to the engine.
        prompts: List[str] = []
        sampling_params: List[SamplingParams] = []
        if warmup:
            print("Running warmup with the first 5 prompts")
            # Use the first 5 requests to warm up the guided-decoding FSM.
            warmup_requests = requests[:5]
            requests = requests[5:]
            for prompt, _, output_len in warmup_requests:
                prompts.append(prompt)
                sampling_params.append(
                    SamplingParams(
                        n=n,
                        temperature=1.0,
                        top_p=1.0,
                        ignore_eos=True,
                        max_tokens=output_len,
                        guided_decoding=GuidedDecodingParams(
                            json=SCHEMA) if guided_decoding else None,
                    ))
            generators = []
            for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
                generator = llm.generate(prompt, sp, request_id=f"test{i}")
                generators.append(generator)
            all_gens = merge_async_iterators(*generators)
            async for i, res in all_gens:
                pass

        prompts = []
        sampling_params = []
        for prompt, _, output_len in requests:
            prompts.append(prompt)
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=output_len,
                    guided_decoding=GuidedDecodingParams(
                        json=SCHEMA) if guided_decoding else None,
                ))

        generators = []
        start = time.perf_counter()
        for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
            generator = llm.generate(prompt, sp, request_id=f"test{i}")
            generators.append(generator)
        all_gens = merge_async_iterators(*generators)
        async for i, res in all_gens:
            pass
        end = time.perf_counter()
        return end - start


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)

    # Build the prompt from the schema and measure its tokenized length.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    prompt = f"Generate an example of a user profile given the following schema: {json.dumps(SCHEMA)}"  # noqa: E501
    input_len = len(tokenizer(prompt).input_ids)
    print(f"Input length of the prompt: {input_len} tokens")
    requests = [(prompt, input_len, args.output_len)
                for _ in range(args.num_prompts)]

    if args.async_engine:
        engine_args = AsyncEngineArgs.from_cli_args(args)
        elapsed_time = uvloop.run(
            run_vllm_async(
                requests,
                engine_args,
                args.n,
                args.guided_decoding,
                args.warmup,
                args.disable_frontend_multiprocessing,
            ))
    else:
        engine_args = EngineArgs.from_cli_args(args)
        elapsed_time = run_vllm(
            requests,
            engine_args,
            args.n,
            args.guided_decoding,
            args.warmup,
        )

    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len in requests)
    total_output_tokens = sum(output_len for _, _, output_len in requests)
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
          f"{total_output_tokens / elapsed_time:.2f} output tokens/s")

    # Output JSON results if specified.
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark guided decoding.")
Review comment (Member):
Could you replace most of these duplicated arguments with the engine's args, like we use in other benchmark scripts?

parser = AsyncEngineArgs.add_cli_args(parser)

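What the reviewer is asking for would look roughly like the sketch below, following the pattern other vLLM benchmark scripts use. This is an assumption, not the PR's code; note that main() reads engine options such as --model, --tokenizer, and --seed, so the engine arguments have to be registered on the parser for the script to run.

# Hypothetical condensed parser setup (assumption, not part of this PR):
# engine/model flags (--model, --tokenizer, --seed, ...) come from
# AsyncEngineArgs, so only benchmark-specific options are declared by hand.
parser = FlexibleArgumentParser(description="Benchmark guided decoding.")
parser.add_argument("--output-len", type=int,
                    help="Output length for each request.")
parser.add_argument("--num-prompts", type=int, default=10,
                    help="Number of prompts to process.")
parser.add_argument("--guided-decoding", action="store_true", default=False,
                    help="Enable JSON-schema guided decoding.")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()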

parser.add_argument("--output-len",
type=int,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--num-prompts",
type=int,
default=10,
help="Number of prompts to process.")
parser.add_argument(
'--output-json',
type=str,
default=None,
help='Path to save the throughput results in JSON format.')
parser.add_argument("--async-engine",
action='store_true',
default=False,
help="Use vLLM async engine rather than LLM class.")
parser.add_argument("--guided-decoding",
action='store_true',
default=False,
help="Whether to enable JSON decoding or not.")
parser.add_argument("--disable-frontend-multiprocessing",
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
parser.add_argument("--warmup",
action="store_true",
default=False,
help="Run warmup prompts before benchmark.")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
main(args)