Skip to content

Commit

Permalink
Re-sample requests for each model
Browse files Browse the repository at this point in the history
  • Loading branch information
liu-cong committed Oct 18, 2024
1 parent b0be375 commit 69b139f
Showing 1 changed file with 12 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -269,11 +269,18 @@ async def send_request(
async def benchmark(
args: argparse.Namespace,
api_url: str,
input_requests: List[Tuple[str, int, int]],
tokenizer: PreTrainedTokenizerBase,
model: str,
) -> Tuple[List[Tuple[int, int, float]], Dict[str, int]]:
"""Runs benchmark with asynchronous requests."""
input_requests = sample_requests(
args.dataset,
args.num_prompts,
args.max_input_length,
args.max_output_length,
tokenizer,
args.use_dummy_text,
)
benchmark_start_time = time.time()
tasks: List[asyncio.Task] = []
async for request in get_request(input_requests, args.request_rate):
Expand Down Expand Up @@ -586,19 +593,12 @@ async def main(args: argparse.Namespace):
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code
)
input_requests = sample_requests(
args.dataset,
args.num_prompts,
args.max_input_length,
args.max_output_length,
tokenizer,
args.use_dummy_text,
)

benchmark_start_time = time.time()
args.start_datetime = datetime.fromtimestamp(benchmark_start_time)

results = await asyncio.gather(
*[benchmark(args, api_url, input_requests, tokenizer, model) for model in models]
*[benchmark(args, api_url, tokenizer, model) for model in models]
)

# Summarize results
Expand All @@ -618,7 +618,7 @@ async def main(args: argparse.Namespace):

benchmark_duration_all_models = time.time() - benchmark_start_time
if args.save_aggregated_result:
print_and_save_result(args, benchmark_duration_all_models, len(models)*len(input_requests), f"ALL-{len(models)}-MODELS", combined_latencies, combined_errors)
print_and_save_result(args, benchmark_duration_all_models, len(models)*args.num_prompts, f"ALL-{len(models)}-MODELS", combined_latencies, combined_errors)

if __name__ == "__main__":
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -709,7 +709,7 @@ async def main(args: argparse.Namespace):
"the request arrival times."
),
)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--seed", type=int, default=int(time.time()))
parser.add_argument(
"--trust-remote-code",
action="store_true",
Expand Down

0 comments on commit 69b139f

Please sign in to comment.