feat: Support batch
Signed-off-by: Ce Gao <[email protected]>
gaocegege committed Feb 15, 2025
1 parent 93ad1cd commit ca51ffa
Showing 8 changed files with 585 additions and 51 deletions.
42 changes: 0 additions & 42 deletions examples/batch.py

This file was deleted.

94 changes: 94 additions & 0 deletions examples/openai_api_client_batch.py
@@ -0,0 +1,94 @@
"""
This script uploads JSONL files to the server, which can be used to run
batch inference on the VLLM model.
"""

import argparse
import time
from pathlib import Path

import rich
from openai import OpenAI

# get the current directory
current_dir = Path(__file__).parent

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="CLI arguments for OpenAI API configuration."
    )
    parser.add_argument(
        "--openai-api-key", type=str, default="NULL", help="Your OpenAI API key"
    )
    parser.add_argument(
        "--openai-api-base",
        type=str,
        default="http://localhost:8000/v1",
        help="Base URL for OpenAI API",
    )
    parser.add_argument(
        "--file-path",
        type=str,
        default="batch.jsonl",
        help="Path to the JSONL file to upload",
    )
    args = parser.parse_args()

    openai_api_key = args.openai_api_key
    openai_api_base = args.openai_api_base

    # generate this file using `./generate_file.sh`
    filepath = current_dir / args.file_path

    client = OpenAI(
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    file = client.files.create(
        file=filepath.read_bytes(),
        purpose="batch",
    )

    # retrieve the file by its id
    retrieved = client.files.retrieve(file.id)
    print("Retrieved file:")
    rich.print(retrieved)

    file_content = client.files.content(file.id)
    print("File content:")
    rich.print(file_content.read().decode())
    file_content.close()

    # create a batch job (the endpoint must match a BatchEndpoint value)
    batch = client.batches.create(
        input_file_id=file.id,
        endpoint="/v1/completions",
        completion_window="1h",
    )
    print("Created batch job:")
    rich.print(batch)

    # retrieve the batch job
    retrieved_batch = client.batches.retrieve(batch.id)
    print("Retrieved batch job:")
    rich.print(retrieved_batch)

    # list all batch jobs
    batches = client.batches.list()
    print("List of batch jobs:")
    rich.print(batches)

    # poll until the batch job finishes; it passes through "pending"
    # and "running" before reaching a terminal status
    while retrieved_batch.status in ("pending", "running"):
        time.sleep(5)
        retrieved_batch = client.batches.retrieve(batch.id)

    # fetch the output file metadata and content
    output_file = client.files.retrieve(retrieved_batch.output_file_id)
    print("Output file:")
    rich.print(output_file)

    output_file_content = client.files.content(output_file.id)
    print("Output file content:")
    rich.print(output_file_content.read().decode())
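
The script assumes batch.jsonl already exists; the commit refers to a ./generate_file.sh helper that is not shown here. As a minimal sketch, each line of that file follows the OpenAI batch input format (custom_id, method, url, body); the model name and prompts below are assumptions for illustration only:

# Hedged sketch of generating batch.jsonl; not part of this commit.
import json

requests = [
    {
        "custom_id": f"request-{i}",
        "method": "POST",
        "url": "/v1/completions",
        # the model name is an assumption; use one served by your deployment
        "body": {"model": "facebook/opt-125m", "prompt": f"Hello {i}", "max_tokens": 16},
    }
    for i in range(2)
]
with open("batch.jsonl", "w") as f:
    for request in requests:
        f.write(json.dumps(request) + "\n")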
24 changes: 24 additions & 0 deletions src/vllm_router/batch/__init__.py
@@ -0,0 +1,24 @@
from vllm_router.batch.batch import BatchEndpoint, BatchInfo, BatchRequest, BatchStatus
from vllm_router.batch.processor import BatchProcessor
from vllm_router.files import Storage


def initialize_batch_processor(
    batch_processor_name: str, storage_path: str, storage: Storage
) -> BatchProcessor:
    if batch_processor_name == "local":
        from vllm_router.batch.local_processor import LocalBatchProcessor

        return LocalBatchProcessor(storage_path, storage)
    else:
        raise ValueError(f"Unknown batch processor: {batch_processor_name}")


__all__ = [
    "BatchEndpoint",
    "BatchInfo",
    "BatchRequest",
    "BatchStatus",
    "BatchProcessor",
    "initialize_batch_processor",
]
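
A hedged usage sketch of the factory above (not part of this commit); it assumes a Storage instance has already been constructed from vllm_router.files, and the storage path is a placeholder:

from vllm_router.batch import initialize_batch_processor
from vllm_router.files import Storage


def make_processor(storage: Storage, storage_path: str = "/tmp/vllm_batch"):
    # "local" selects LocalBatchProcessor; any other name raises ValueError
    return initialize_batch_processor("local", storage_path, storage)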
91 changes: 91 additions & 0 deletions src/vllm_router/batch/batch.py
@@ -0,0 +1,91 @@
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional


class BatchStatus(str, Enum):
    """
    Represents the status of a batch job.
    """

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"


class BatchEndpoint(str, Enum):
    """
    Represents the available OpenAI API endpoints for batch requests.
    Ref https://platform.openai.com/docs/api-reference/batch/create#batch-create-endpoint.
    """

    CHAT_COMPLETION = "/v1/chat/completions"
    EMBEDDING = "/v1/embeddings"
    COMPLETION = "/v1/completions"


@dataclass
class BatchRequest:
    """Represents a single request in a batch."""

    input_file_id: str
    endpoint: BatchEndpoint
    completion_window: str
    metadata: Optional[Dict[str, Any]] = None


@dataclass
class BatchInfo:
    """
    Represents batch job information.
    Ref https://platform.openai.com/docs/api-reference/batch/object
    """

    id: str
    status: BatchStatus
    input_file_id: str
    created_at: int
    endpoint: str
    completion_window: str
    output_file_id: Optional[str] = None
    error_file_id: Optional[str] = None
    in_progress_at: Optional[int] = None
    expires_at: Optional[int] = None
    finalizing_at: Optional[int] = None
    completed_at: Optional[int] = None
    failed_at: Optional[int] = None
    expired_at: Optional[int] = None
    cancelling_at: Optional[int] = None
    cancelled_at: Optional[int] = None
    total_requests: Optional[int] = None
    completed_requests: int = 0
    failed_requests: int = 0
    metadata: Optional[Dict[str, Any]] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert the instance to a dictionary."""
        return {
            "id": self.id,
            "status": self.status.value,
            "input_file_id": self.input_file_id,
            "created_at": self.created_at,
            "endpoint": self.endpoint,
            "completion_window": self.completion_window,
            "output_file_id": self.output_file_id,
            "error_file_id": self.error_file_id,
            "in_progress_at": self.in_progress_at,
            "expires_at": self.expires_at,
            "finalizing_at": self.finalizing_at,
            "completed_at": self.completed_at,
            "failed_at": self.failed_at,
            "expired_at": self.expired_at,
            "cancelling_at": self.cancelling_at,
            "cancelled_at": self.cancelled_at,
            "total_requests": self.total_requests,
            "completed_requests": self.completed_requests,
            "failed_requests": self.failed_requests,
            "metadata": self.metadata,
        }
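
For illustration (not part of this commit), constructing a BatchInfo with only its required fields and serializing it might look like:

import time

info = BatchInfo(
    id="batch_123",
    status=BatchStatus.PENDING,
    input_file_id="file_abc",
    created_at=int(time.time()),
    endpoint=BatchEndpoint.COMPLETION.value,
    completion_window="1h",
)
# to_dict() flattens the enum to its string value for JSON responses
print(info.to_dict()["status"])  # -> "pending"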