-
Notifications
You must be signed in to change notification settings - Fork 70
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Ce Gao <[email protected]>
- Loading branch information
Showing
8 changed files
with
585 additions
and
51 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
""" | ||
This script uploads JSONL files to the server, which can be used to run | ||
batch inference on the VLLM model. | ||
""" | ||
|
||
import argparse | ||
import time | ||
from pathlib import Path | ||
|
||
import rich | ||
from openai import OpenAI | ||
|
||
# Resolve paths relative to this script's directory so the upload works
# no matter which working directory the script is launched from.
current_dir = Path(__file__).parent


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="CLI arguments for OpenAI API configuration."
    )
    parser.add_argument(
        "--openai-api-key", type=str, default="NULL", help="Your OpenAI API key"
    )
    parser.add_argument(
        "--openai-api-base",
        type=str,
        default="http://localhost:8000/v1",
        help="Base URL for OpenAI API",
    )
    parser.add_argument(
        "--file-path",
        type=str,
        default="batch.jsonl",
        help="Path to the JSONL file to upload",
    )
    args = parser.parse_args()

    openai_api_key = args.openai_api_key
    openai_api_base = args.openai_api_base

    # generate this file using `./generate_file.sh`
    filepath = current_dir / args.file_path

    client = OpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    # Upload the JSONL payload; purpose="batch" marks it as batch input.
    file = client.files.create(
        file=filepath.read_bytes(),
        purpose="batch",
    )

    # get the file according to the file id
    retrieved = client.files.retrieve(file.id)
    print("Retrieved file:")
    rich.print(retrieved)

    file_content = client.files.content(file.id)
    print("File content:")
    rich.print(file_content.read().decode())
    file_content.close()

    # Create a batch job.
    # BUG FIX: the endpoint must be the full "/v1/completions" path — that is
    # what the router's BatchEndpoint enum declares and what the OpenAI Batch
    # API documents; the bare "/completions" does not match either.
    batch = client.batches.create(
        input_file_id=file.id,
        endpoint="/v1/completions",
        completion_window="1h",
    )
    print("Created batch job:")
    rich.print(batch)

    # retrieve the batch job
    retrieved_batch = client.batches.retrieve(batch.id)
    print("Retrieved batch job:")
    rich.print(retrieved_batch)

    # list all batch jobs
    batches = client.batches.list()
    print("List of batch jobs:")
    rich.print(batches)

    # Poll until the job leaves both transient states.
    # BUG FIX: the original loop only watched "pending", so it stopped
    # polling as soon as the job moved to "running" (see BatchStatus) and
    # then tried to read an output file that did not exist yet.
    while retrieved_batch.status in ("pending", "running"):
        time.sleep(5)
        retrieved_batch = client.batches.retrieve(batch.id)

    # Fail loudly instead of dereferencing a missing output_file_id when the
    # job ended in "failed".
    if retrieved_batch.status != "completed":
        raise RuntimeError(
            f"Batch job {batch.id} finished with status {retrieved_batch.status}"
        )

    # get the output file content
    output_file = client.files.retrieve(retrieved_batch.output_file_id)
    print("Output file:")
    rich.print(output_file)

    output_file_content = client.files.content(output_file.id)
    print("Output file content:")
    rich.print(output_file_content.read().decode())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from vllm_router.batch.batch import BatchEndpoint, BatchInfo, BatchRequest, BatchStatus | ||
from vllm_router.batch.processor import BatchProcessor | ||
from vllm_router.files import Storage | ||
|
||
|
||
def initialize_batch_processor(
    batch_processor_name: str, storage_path: str, storage: Storage
) -> BatchProcessor:
    """
    Build the batch processor implementation selected by name.

    Args:
        batch_processor_name: Identifier of the implementation; only
            "local" is currently supported.
        storage_path: Filesystem path handed to the processor.
        storage: File storage backend the processor reads from/writes to.

    Returns:
        A ready-to-use BatchProcessor instance.

    Raises:
        ValueError: If ``batch_processor_name`` is not recognized.
    """
    # Guard clause: reject unknown names up front.
    if batch_processor_name != "local":
        raise ValueError(f"Unknown batch processor: {batch_processor_name}")

    # Imported lazily so the local-processor module is only loaded when
    # this backend is actually selected.
    from vllm_router.batch.local_processor import LocalBatchProcessor

    return LocalBatchProcessor(storage_path, storage)
|
||
|
||
# Public API of the batch package: re-exported types plus the factory above.
__all__ = [
    "BatchEndpoint",
    "BatchInfo",
    "BatchRequest",
    "BatchStatus",
    "BatchProcessor",
    "initialize_batch_processor",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from dataclasses import dataclass | ||
from enum import Enum | ||
from typing import Any, Dict, Optional | ||
|
||
|
||
class BatchStatus(str, Enum):
    """
    Represents the status of a batch job.

    Mixes in ``str`` so members compare equal to their plain string values
    (e.g. ``BatchStatus.PENDING == "pending"``) and serialize naturally.
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
|
||
|
||
class BatchEndpoint(str, Enum):
    """
    Represents the available OpenAI API endpoints for batch requests.
    Ref https://platform.openai.com/docs/api-reference/batch/create#batch-create-endpoint.

    Mixes in ``str`` so members can be compared against raw endpoint paths
    taken from incoming requests.
    """

    CHAT_COMPLETION = "/v1/chat/completions"
    EMBEDDING = "/v1/embeddings"
    COMPLETION = "/v1/completions"
|
||
|
||
@dataclass
class BatchRequest:
    """Represents a single request in a batch"""

    # ID of the previously uploaded JSONL input file to process.
    input_file_id: str
    # OpenAI-compatible endpoint every line of the input file targets.
    endpoint: BatchEndpoint
    # Time window string in which the batch should complete (e.g. "1h").
    completion_window: str
    # Optional caller-supplied key/value annotations attached to the job.
    metadata: Optional[Dict[str, Any]] = None
|
||
|
||
@dataclass
class BatchInfo:
    """
    Represents batch job information
    Ref https://platform.openai.com/docs/api-reference/batch/object
    """

    # Required fields, set at creation time.
    id: str
    status: BatchStatus
    input_file_id: str
    created_at: int
    endpoint: str
    completion_window: str
    # Lifecycle fields, populated as the job progresses through its states.
    output_file_id: Optional[str] = None
    error_file_id: Optional[str] = None
    in_progress_at: Optional[int] = None
    expires_at: Optional[int] = None
    finalizing_at: Optional[int] = None
    completed_at: Optional[int] = None
    failed_at: Optional[int] = None
    expired_at: Optional[int] = None
    cancelling_at: Optional[int] = None
    cancelled_at: Optional[int] = None
    # Progress counters.
    total_requests: Optional[int] = None
    completed_requests: int = 0
    failed_requests: int = 0
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the instance to a dictionary.

        Keys follow field declaration order; ``status`` is flattened to its
        string value so the result is directly JSON-serializable.
        """
        # Local import: the module header only imports `dataclass`.
        from dataclasses import fields

        # Deriving the mapping from the dataclass fields keeps to_dict in
        # sync automatically when fields are added or removed — the previous
        # hand-written dict repeated all twenty field names and would
        # silently drop any field added later.
        result = {f.name: getattr(self, f.name) for f in fields(self)}
        result["status"] = self.status.value
        return result
Oops, something went wrong.