Run v1 benchmark and integrate with PyTorch OSS benchmark database (vllm-project#13068)

Signed-off-by: Huy Do <[email protected]>
huydhn authored and kerthcet committed Feb 21, 2025
1 parent 846c663 commit 4a8bc59
Showing 7 changed files with 167 additions and 45 deletions.
@@ -345,6 +345,11 @@ main() {
   check_gpus
   check_hf_token
 
+  # Set to v1 to run v1 benchmark
+  if [[ "${ENGINE_VERSION:-v0}" == "v1" ]]; then
+    export VLLM_USE_V1=1
+  fi
+
   # dependencies
   (which wget && which curl) || (apt-get update && apt-get install -y wget curl)
   (which jq) || (apt-get update && apt-get -y install jq)
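With this gate in place, launching the suite with ENGINE_VERSION=v1 in the environment (for example, `ENGINE_VERSION=v1 bash run-benchmarks.sh`, a hypothetical invocation, since the script's file name is not captured in this diff) exports VLLM_USE_V1=1 so the run exercises vLLM's v1 engine; any other value, or leaving the variable unset, keeps the v0 default.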
2 changes: 1 addition & 1 deletion .buildkite/nightly-benchmarks/tests/latency-tests.json
@@ -29,4 +29,4 @@
             "num-iters": 15
         }
     }
-]
\ No newline at end of file
+]
91 changes: 58 additions & 33 deletions benchmarks/benchmark_latency.py
@@ -1,14 +1,17 @@
+# SPDX-License-Identifier: Apache-2.0
 """Benchmark the latency of processing a single batch of requests."""
 
 import argparse
 import dataclasses
 import json
+import os
 import time
 from pathlib import Path
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 import numpy as np
 import torch
+from benchmark_utils import convert_to_pytorch_benchmark_format
 from tqdm import tqdm
 
 from vllm import LLM, SamplingParams
@@ -18,6 +21,19 @@
 from vllm.utils import FlexibleArgumentParser
 
 
+def save_to_pytorch_benchmark_format(args: argparse.Namespace,
+                                     results: Dict[str, Any]) -> None:
+    pt_records = convert_to_pytorch_benchmark_format(
+        args=args,
+        metrics={"latency": results["latencies"]},
+        extra_info={k: results[k]
+                    for k in ["avg_latency", "percentiles"]})
+    if pt_records:
+        pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
+        with open(pt_file, "w") as f:
+            json.dump(pt_records, f)
+
+
 def main(args: argparse.Namespace):
     print(args)
 
@@ -54,7 +70,8 @@ def llm_generate():
                     beam_width=args.n,
                     max_tokens=args.output_len,
                     ignore_eos=True,
-                ))
+                ),
+            )
 
     def run_to_completion(profile_dir: Optional[str] = None):
         if profile_dir:
@@ -64,7 +81,8 @@ def run_to_completion(profile_dir: Optional[str] = None):
                         torch.profiler.ProfilerActivity.CUDA,
                     ],
                     on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                        str(profile_dir))) as p:
+                        str(profile_dir)),
+            ) as p:
                 llm_generate()
             print(p.key_averages().table(sort_by="self_cuda_time_total"))
         else:
@@ -81,9 +99,8 @@ def run_to_completion(profile_dir: Optional[str] = None):
     if args.profile:
         profile_dir = args.profile_result_dir
         if not profile_dir:
-            profile_dir = Path(
-                "."
-            ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
+            profile_dir = (Path(".") / "vllm_benchmark_result" /
+                           f"latency_result_{time.time()}")
         print(f"Profiling (results will be saved to '{profile_dir}')...")
         run_to_completion(profile_dir=profile_dir)
         return
@@ -95,9 +112,9 @@ def run_to_completion(profile_dir: Optional[str] = None):
     latencies = np.array(latencies)
     percentages = [10, 25, 50, 75, 90, 99]
     percentiles = np.percentile(latencies, percentages)
-    print(f'Avg latency: {np.mean(latencies)} seconds')
+    print(f"Avg latency: {np.mean(latencies)} seconds")
     for percentage, percentile in zip(percentages, percentiles):
-        print(f'{percentage}% percentile latency: {percentile} seconds')
+        print(f"{percentage}% percentile latency: {percentile} seconds")
 
     # Output JSON results if specified
     if args.output_json:
@@ -108,43 +125,51 @@ def run_to_completion(profile_dir: Optional[str] = None):
         }
         with open(args.output_json, "w") as f:
             json.dump(results, f, indent=4)
+        save_to_pytorch_benchmark_format(args, results)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = FlexibleArgumentParser(
-        description='Benchmark the latency of processing a single batch of '
-        'requests till completion.')
-    parser.add_argument('--input-len', type=int, default=32)
-    parser.add_argument('--output-len', type=int, default=128)
-    parser.add_argument('--batch-size', type=int, default=8)
-    parser.add_argument('--n',
-                        type=int,
-                        default=1,
-                        help='Number of generated sequences per prompt.')
-    parser.add_argument('--use-beam-search', action='store_true')
-    parser.add_argument('--num-iters-warmup',
-                        type=int,
-                        default=10,
-                        help='Number of iterations to run for warmup.')
-    parser.add_argument('--num-iters',
+        description="Benchmark the latency of processing a single batch of "
+        "requests till completion.")
+    parser.add_argument("--input-len", type=int, default=32)
+    parser.add_argument("--output-len", type=int, default=128)
+    parser.add_argument("--batch-size", type=int, default=8)
+    parser.add_argument(
+        "--n",
+        type=int,
+        default=1,
+        help="Number of generated sequences per prompt.",
+    )
+    parser.add_argument("--use-beam-search", action="store_true")
+    parser.add_argument(
+        "--num-iters-warmup",
+        type=int,
+        default=10,
+        help="Number of iterations to run for warmup.",
+    )
+    parser.add_argument("--num-iters",
                         type=int,
                         default=30,
-                        help='Number of iterations to run.')
+                        help="Number of iterations to run.")
     parser.add_argument(
-        '--profile',
-        action='store_true',
-        help='profile the generation process of a single batch')
+        "--profile",
+        action="store_true",
+        help="profile the generation process of a single batch",
+    )
    parser.add_argument(
-        '--profile-result-dir',
+        "--profile-result-dir",
         type=str,
         default=None,
-        help=('path to save the pytorch profiler output. Can be visualized '
-              'with ui.perfetto.dev or Tensorboard.'))
+        help=("path to save the pytorch profiler output. Can be visualized "
+              "with ui.perfetto.dev or Tensorboard."),
+    )
     parser.add_argument(
-        '--output-json',
+        "--output-json",
         type=str,
         default=None,
-        help='Path to save the latency results in JSON format.')
+        help="Path to save the latency results in JSON format.",
+    )
 
     parser = EngineArgs.add_cli_args(parser)
     args = parser.parse_args()
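A minimal sketch of the side-car naming convention used by save_to_pytorch_benchmark_format above (the path here is hypothetical): the helper strips the extension from the --output-json path with os.path.splitext and appends .pytorch.json, so the PyTorch-format records land next to the regular results file:

    import os

    output_json = "results/latency.json"  # hypothetical --output-json value
    pt_file = f"{os.path.splitext(output_json)[0]}.pytorch.json"
    print(pt_file)  # prints: results/latency.pytorch.json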
48 changes: 39 additions & 9 deletions benchmarks/benchmark_serving.py
@@ -56,6 +56,8 @@
 except ImportError:
     from argparse import ArgumentParser as FlexibleArgumentParser
 
+from benchmark_utils import convert_to_pytorch_benchmark_format
+
 MILLISECONDS_TO_SECONDS_CONVERSION = 1000


@@ -402,21 +404,21 @@ async def get_request(
     burstiness: float = 1.0,
 ) -> AsyncGenerator[Tuple[str, int, int], None]:
     """
-    Asynchronously generates requests at a specified rate 
+    Asynchronously generates requests at a specified rate
     with OPTIONAL burstiness.
-    
+
     Args:
-        input_requests: 
+        input_requests:
             A list of input requests, each represented as a tuple.
-        request_rate: 
+        request_rate:
             The rate at which requests are generated (requests/s).
-        burstiness (optional): 
-            The burstiness factor of the request generation. 
+        burstiness (optional):
+            The burstiness factor of the request generation.
             Only takes effect when request_rate is not inf.
             Default value is 1, which follows a Poisson process.
             Otherwise, the request intervals follow a gamma distribution.
-            A lower burstiness value (0 < burstiness < 1) results 
-            in more bursty requests, while a higher burstiness value 
+            A lower burstiness value (0 < burstiness < 1) results
+            in more bursty requests, while a higher burstiness value
             (burstiness > 1) results in a more uniform arrival of requests.
     """
     input_requests = iter(input_requests)
@@ -817,6 +819,32 @@ def parse_goodput(slo_pairs):
     return goodput_config_dict
 
 
+def save_to_pytorch_benchmark_format(args: argparse.Namespace,
+                                     results: Dict[str, Any],
+                                     file_name: str) -> None:
+    metrics = [
+        "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms",
+        "mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms",
+        "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms"
+    ]
+    # These raw data might be useful, but they are rather big. They can be added
+    # later if needed
+    ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"]
+    pt_records = convert_to_pytorch_benchmark_format(
+        args=args,
+        metrics={k: [results[k]]
+                 for k in metrics},
+        extra_info={
+            k: results[k]
+            for k in results if k not in metrics and k not in ignored_metrics
+        })
+    if pt_records:
+        # Don't use json suffix here as we don't want CI to pick it up
+        pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json"
+        with open(pt_file, "w") as f:
+            json.dump(pt_records, f)
+
+
 def main(args: argparse.Namespace):
     print(args)
     random.seed(args.seed)
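For a sense of what the comprehensions above produce, here is a hypothetical excerpt (only two of the twelve aggregate metrics shown): each aggregate becomes a singleton benchmark_values list, while raw per-request series such as ttfts are filtered out:

    # Hypothetical excerpt of a serving run's result_json:
    results = {"median_ttft_ms": 42.0, "mean_ttft_ms": 45.3, "ttfts": [0.04, 0.05]}
    metrics = ["median_ttft_ms", "mean_ttft_ms"]
    ignored_metrics = ["ttfts"]

    metric_values = {k: [results[k]] for k in metrics}
    extra_info = {k: results[k] for k in results
                  if k not in metrics and k not in ignored_metrics}
    print(metric_values)  # {'median_ttft_ms': [42.0], 'mean_ttft_ms': [45.3]}
    print(extra_info)     # {} ('ttfts' is dropped as raw data)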
@@ -997,6 +1025,7 @@ def main(args: argparse.Namespace):
         file_name = os.path.join(args.result_dir, file_name)
         with open(file_name, "w", encoding='utf-8') as outfile:
             json.dump(result_json, outfile)
+        save_to_pytorch_benchmark_format(args, result_json, file_name)
 
 
 if __name__ == "__main__":
@@ -1014,7 +1043,8 @@ def main(args: argparse.Namespace):
         default=None,
         help="Server or API base url if not using http host and port.",
     )
-    parser.add_argument("--host", type=str, default="localhost")
+    # Use 127.0.0.1 here instead of localhost to force the use of ipv4
+    parser.add_argument("--host", type=str, default="127.0.0.1")
     parser.add_argument("--port", type=int, default=8000)
     parser.add_argument(
         "--endpoint",
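The switch from localhost to 127.0.0.1 matters because localhost can resolve to the IPv6 loopback (::1) on some systems, in which case the benchmark client and the server may disagree on the address family; pinning the literal IPv4 loopback avoids that mismatch, as the inline comment notes.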
3 changes: 2 additions & 1 deletion benchmarks/benchmark_serving_guided.py
@@ -731,7 +731,8 @@ def main(args: argparse.Namespace):
         default=None,
         help="Server or API base url if not using http host and port.",
     )
-    parser.add_argument("--host", type=str, default="localhost")
+    # Use 127.0.0.1 here instead of localhost to force the use of ipv4
+    parser.add_argument("--host", type=str, default="127.0.0.1")
     parser.add_argument("--port", type=int, default=8000)
     parser.add_argument(
         "--endpoint",
24 changes: 23 additions & 1 deletion benchmarks/benchmark_throughput.py
@@ -3,13 +3,15 @@
 import argparse
 import dataclasses
 import json
+import os
 import random
 import time
 from functools import cache
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 import uvloop
+from benchmark_utils import convert_to_pytorch_benchmark_format
 from PIL import Image
 from tqdm import tqdm
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
@@ -338,6 +340,25 @@ def run_mii(
     return end - start
 
 
+def save_to_pytorch_benchmark_format(args: argparse.Namespace,
+                                     results: Dict[str, Any]) -> None:
+    pt_records = convert_to_pytorch_benchmark_format(
+        args=args,
+        metrics={
+            "requests_per_second": [results["requests_per_second"]],
+            "tokens_per_second": [results["tokens_per_second"]],
+        },
+        extra_info={
+            k: results[k]
+            for k in ["elapsed_time", "num_requests", "total_num_tokens"]
+        })
+    if pt_records:
+        # Don't use json suffix here as we don't want CI to pick it up
+        pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json"
+        with open(pt_file, "w") as f:
+            json.dump(pt_records, f)
+
+
 def main(args: argparse.Namespace):
     print(args)
     random.seed(args.seed)
@@ -435,6 +456,7 @@ def main(args: argparse.Namespace):
         }
         with open(args.output_json, "w") as f:
             json.dump(results, f, indent=4)
+        save_to_pytorch_benchmark_format(args, results)
 
 
 if __name__ == "__main__":
39 changes: 39 additions & 0 deletions benchmarks/benchmark_utils.py
@@ -0,0 +1,39 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+import os
+from typing import Any, Dict, List
+
+
+def convert_to_pytorch_benchmark_format(args: argparse.Namespace,
+                                        metrics: Dict[str, List],
+                                        extra_info: Dict[str, Any]) -> List:
+    """
+    Save the benchmark results in the format used by PyTorch OSS benchmark with
+    one metric per record
+    https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
+    """
+    records = []
+    if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False):
+        return records
+
+    for name, benchmark_values in metrics.items():
+        record = {
+            "benchmark": {
+                "name": "vLLM benchmark",
+                "extra_info": {
+                    "args": vars(args),
+                },
+            },
+            "model": {
+                "name": args.model,
+            },
+            "metric": {
+                "name": name,
+                "benchmark_values": benchmark_values,
+                "extra_info": extra_info,
+            },
+        }
+        records.append(record)
+
+    return records
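A minimal sketch of driving the new converter end to end (the model name and metric values are hypothetical, and it assumes the benchmarks/ directory is on PYTHONPATH):

    import argparse
    import os

    # The converter is opt-in: it returns [] unless this variable is set.
    # Note the gate is a plain truthiness check on the string, so any
    # non-empty value (even "0") enables it.
    os.environ["SAVE_TO_PYTORCH_BENCHMARK_FORMAT"] = "1"

    from benchmark_utils import convert_to_pytorch_benchmark_format

    args = argparse.Namespace(model="my-org/my-model", batch_size=8)  # hypothetical args
    records = convert_to_pytorch_benchmark_format(
        args=args,
        metrics={"latency": [1.21, 1.19, 1.24]},  # one record per metric name
        extra_info={"avg_latency": 1.213},
    )
    assert records[0]["model"]["name"] == "my-org/my-model"
    assert records[0]["metric"]["benchmark_values"] == [1.21, 1.19, 1.24]

Each record carries the schema the PyTorch OSS benchmark database expects: the benchmark name and full CLI args under "benchmark", the model name under "model", and the metric payload under "metric".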
