From ca51ffa4d78cb7191a6d7ef154c20213c15eb353 Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Sat, 15 Feb 2025 13:23:30 +0800 Subject: [PATCH] feat: Support batch Signed-off-by: Ce Gao --- examples/batch.py | 42 ----- examples/openai_api_client_batch.py | 94 ++++++++++ src/vllm_router/batch/__init__.py | 24 +++ src/vllm_router/batch/batch.py | 91 ++++++++++ src/vllm_router/batch/local_processor.py | 207 +++++++++++++++++++++++ src/vllm_router/batch/processor.py | 45 +++++ src/vllm_router/files/file_storage.py | 2 +- src/vllm_router/router.py | 131 +++++++++++++- 8 files changed, 585 insertions(+), 51 deletions(-) delete mode 100644 examples/batch.py create mode 100644 examples/openai_api_client_batch.py create mode 100644 src/vllm_router/batch/__init__.py create mode 100644 src/vllm_router/batch/batch.py create mode 100644 src/vllm_router/batch/local_processor.py create mode 100644 src/vllm_router/batch/processor.py diff --git a/examples/batch.py b/examples/batch.py deleted file mode 100644 index 8f384d0..0000000 --- a/examples/batch.py +++ /dev/null @@ -1,42 +0,0 @@ -""" -This script uploads JSONL files to the server, which can be used to run -batch inference on the VLLM model. -""" - -from pathlib import Path - -import rich -from openai import OpenAI - -# get the current directory -current_dir = Path(__file__).parent -# generate this file using `./generate_file.sh` -filepath = current_dir / "batch.jsonl" - -# Modify OpenAI's API key and API base to use vLLM's API server. -openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8000/v1" - -client = OpenAI( - api_key=openai_api_key, - base_url=openai_api_base, -) - - -def from_in_memory() -> None: - file = client.files.create( - file=filepath.read_bytes(), - purpose="batch", - ) - return file - - -if __name__ == "__main__": - file = from_in_memory() - - # get the file according to the file id - retrieved = client.files.retrieve(file.id) - rich.print(retrieved) - - file_content = client.files.retrieve_content(file.id) - rich.print(file_content.encode("utf-8")) diff --git a/examples/openai_api_client_batch.py b/examples/openai_api_client_batch.py new file mode 100644 index 0000000..86edb43 --- /dev/null +++ b/examples/openai_api_client_batch.py @@ -0,0 +1,94 @@ +""" +This script uploads JSONL files to the server, which can be used to run +batch inference on the VLLM model. +""" + +import argparse +import time +from pathlib import Path + +import rich +from openai import OpenAI + +# get the current directory +current_dir = Path(__file__).parent + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="CLI arguments for OpenAI API configuration." 
+ ) + parser.add_argument( + "--openai-api-key", type=str, default="NULL", help="Your OpenAI API key" + ) + parser.add_argument( + "--openai-api-base", + type=str, + default="http://localhost:8000/v1", + help="Base URL for OpenAI API", + ) + parser.add_argument( + "--file-path", + type=str, + default="batch.jsonl", + help="Path to the JSONL file to upload", + ) + args = parser.parse_args() + + openai_api_key = args.openai_api_key + openai_api_base = args.openai_api_base + + # generate this file using `./generate_file.sh` + filepath = current_dir / args.file_path + + client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, + ) + + file = client.files.create( + file=filepath.read_bytes(), + purpose="batch", + ) + + # get the file according to the file id + retrieved = client.files.retrieve(file.id) + print("Retrieved file:") + rich.print(retrieved) + + file_content = client.files.content(file.id) + print("File content:") + rich.print(file_content.read().decode()) + file_content.close() + + # create a batch job + batch = client.batches.create( + input_file_id=file.id, + endpoint="/completions", + completion_window="1h", + ) + print("Created batch job:") + rich.print(batch) + + # retrieve the batch job + retrieved_batch = client.batches.retrieve(batch.id) + print("Retrieved batch job:") + rich.print(retrieved_batch) + + # list all batch jobs + batches = client.batches.list() + print("List of batch jobs:") + rich.print(batches) + + # wait for the batch job to complete + while retrieved_batch.status == "pending": + time.sleep(5) + retrieved_batch = client.batches.retrieve(batch.id) + + # get the output file content + output_file = client.files.retrieve(retrieved_batch.output_file_id) + print("Output file:") + rich.print(output_file) + + output_file_content = client.files.content(output_file.id) + print("Output file content:") + rich.print(output_file_content.read().decode()) diff --git a/src/vllm_router/batch/__init__.py b/src/vllm_router/batch/__init__.py new file mode 100644 index 0000000..1afc718 --- /dev/null +++ b/src/vllm_router/batch/__init__.py @@ -0,0 +1,24 @@ +from vllm_router.batch.batch import BatchEndpoint, BatchInfo, BatchRequest, BatchStatus +from vllm_router.batch.processor import BatchProcessor +from vllm_router.files import Storage + + +def initialize_batch_processor( + batch_processor_name: str, storage_path: str, storage: Storage +) -> BatchProcessor: + if batch_processor_name == "local": + from vllm_router.batch.local_processor import LocalBatchProcessor + + return LocalBatchProcessor(storage_path, storage) + else: + raise ValueError(f"Unknown batch processor: {batch_processor_name}") + + +__all__ = [ + "BatchEndpoint", + "BatchInfo", + "BatchRequest", + "BatchStatus", + "BatchProcessor", + "initialize_batch_processor", +] diff --git a/src/vllm_router/batch/batch.py b/src/vllm_router/batch/batch.py new file mode 100644 index 0000000..7f9b56e --- /dev/null +++ b/src/vllm_router/batch/batch.py @@ -0,0 +1,91 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, Optional + + +class BatchStatus(str, Enum): + """ + Represents the status of a batch job. + """ + + PENDING = "pending" + RUNNING = "running" + COMPLETED = "completed" + FAILED = "failed" + + +class BatchEndpoint(str, Enum): + """ + Represents the available OpenAI API endpoints for batch requests. + + Ref https://platform.openai.com/docs/api-reference/batch/create#batch-create-endpoint. 
+ """ + + CHAT_COMPLETION = "/v1/chat/completions" + EMBEDDING = "/v1/embeddings" + COMPLETION = "/v1/completions" + + +@dataclass +class BatchRequest: + """Represents a single request in a batch""" + + input_file_id: str + endpoint: BatchEndpoint + completion_window: str + metadata: Optional[Dict[str, Any]] = None + + +@dataclass +class BatchInfo: + """ + Represents batch job information + + Ref https://platform.openai.com/docs/api-reference/batch/object + """ + + id: str + status: BatchStatus + input_file_id: str + created_at: int + endpoint: str + completion_window: str + output_file_id: Optional[str] = None + error_file_id: Optional[str] = None + in_progress_at: Optional[int] = None + expires_at: Optional[int] = None + finalizing_at: Optional[int] = None + completed_at: Optional[int] = None + failed_at: Optional[int] = None + expired_at: Optional[int] = None + cancelling_at: Optional[int] = None + cancelled_at: Optional[int] = None + total_requests: Optional[int] = None + completed_requests: int = 0 + failed_requests: int = 0 + metadata: Optional[Dict[str, Any]] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert the instance to a dictionary.""" + return { + "id": self.id, + "status": self.status.value, + "input_file_id": self.input_file_id, + "created_at": self.created_at, + "endpoint": self.endpoint, + "completion_window": self.completion_window, + "output_file_id": self.output_file_id, + "error_file_id": self.error_file_id, + "in_progress_at": self.in_progress_at, + "expires_at": self.expires_at, + "finalizing_at": self.finalizing_at, + "completed_at": self.completed_at, + "failed_at": self.failed_at, + "expired_at": self.expired_at, + "cancelling_at": self.cancelling_at, + "cancelled_at": self.cancelled_at, + "total_requests": self.total_requests, + "completed_requests": self.completed_requests, + "failed_requests": self.failed_requests, + "metadata": self.metadata, + } diff --git a/src/vllm_router/batch/local_processor.py b/src/vllm_router/batch/local_processor.py new file mode 100644 index 0000000..c98a3a7 --- /dev/null +++ b/src/vllm_router/batch/local_processor.py @@ -0,0 +1,207 @@ +import asyncio +import datetime +import json +import os +import time +from typing import List, Optional +from uuid import uuid4 + +import aiosqlite + +from vllm_router.batch.batch import BatchInfo, BatchStatus +from vllm_router.batch.processor import BatchProcessor +from vllm_router.files import Storage +from vllm_router.log import init_logger + +logger = init_logger(__name__) + + +class LocalBatchProcessor(BatchProcessor): + """SQLite-backed batch processor with background processing.""" + + def __init__(self, db_dir: str, storage: Storage): + super().__init__(storage) + os.makedirs(db_dir, exist_ok=True) + self.db_path = os.path.join(db_dir, "batch_queue.db") + self._initialized = False + + async def initialize(self): + if not self._initialized: + logger.info( + "Initializing LocalBatchProcessor with SQLite DB at %s", self.db_path + ) + await self.setup_db() + asyncio.create_task(self.process_batches()) + self._initialized = True + + async def setup_db(self): + """Setup the SQLite table for all batch fields.""" + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + "CREATE TABLE IF NOT EXISTS batch_queue (" + "batch_id TEXT PRIMARY KEY, " + "status TEXT, " + "input_file_id TEXT, " + "created_at INTEGER, " + "endpoint TEXT, " + "completion_window TEXT, " + "output_file_id TEXT, " + "completed_at INTEGER, " + "metadata TEXT" + ")" + ) + await db.commit() + + async def 
create_batch( + self, + input_file_id: str, + endpoint: str, + completion_window: str, + metadata: Optional[dict] = None, + ) -> BatchInfo: + batch_id = "batch_" + uuid4().hex[:6] + ts_now = int(time.time()) + batch_info = BatchInfo( + id=batch_id, + status=BatchStatus.PENDING, + input_file_id=input_file_id, + created_at=ts_now, + endpoint=endpoint, + completion_window=completion_window, + output_file_id=None, + completed_at=None, + metadata=metadata or {}, + ) + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + "INSERT INTO batch_queue (batch_id, status, input_file_id, created_at, endpoint, completion_window, output_file_id, completed_at, metadata) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + ( + batch_id, + BatchStatus.PENDING, + input_file_id, + ts_now, + endpoint, + completion_window, + None, + None, + json.dumps(batch_info.metadata), + ), + ) + await db.commit() + logger.info("Created batch job %s", batch_id) + return batch_info + + async def retrieve_batch(self, batch_id: str) -> BatchInfo: + async with aiosqlite.connect(self.db_path) as db: + async with db.execute( + "SELECT status, input_file_id, created_at, endpoint, completion_window, output_file_id, completed_at, metadata FROM batch_queue WHERE batch_id = ?", + (batch_id,), + ) as cursor: + row = await cursor.fetchone() + if row is None: + raise ValueError(f"Batch {batch_id} not found") + ( + status, + input_file_id, + created_at, + endpoint, + completion_window, + output_file_id, + completed_at, + metadata, + ) = row + # Convert status string to BatchStatus enum. + from vllm_router.batch.batch import BatchStatus + + return BatchInfo( + id=batch_id, + status=BatchStatus(status), + input_file_id=input_file_id, + created_at=created_at, + endpoint=endpoint, + completion_window=completion_window, + output_file_id=output_file_id, + completed_at=completed_at, + metadata=json.loads(metadata) if metadata else {}, + ) + + async def list_batches( + self, limit: int = 100, after: str = None + ) -> List[BatchInfo]: + async with aiosqlite.connect(self.db_path) as db: + query = "SELECT batch_id FROM batch_queue" + params = () + if after: + query += " WHERE created_at > ?" + params = (after,) + query += " ORDER BY created_at DESC LIMIT ?" + params += (limit,) + async with db.execute(query, params) as cursor: + rows = await cursor.fetchall() + return [await self.retrieve_batch(row[0]) for row in rows] + + async def cancel_batch(self, batch_id: str) -> BatchInfo: + # Retrieve current batch info + batch_info = await self.retrieve_batch(batch_id) + if batch_info.status not in [BatchStatus.COMPLETED, BatchStatus.FAILED]: + batch_info.status = "cancelled" + batch_info.completed_at = int(time.time()) + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + "UPDATE batch_queue SET status = ?, completed_at = ? WHERE batch_id = ?", + (batch_info.status, batch_info.completed_at, batch_id), + ) + await db.commit() + return batch_info + + async def process_batches(self): + """Continuously poll the DB for pending batches and process them.""" + while True: + async with aiosqlite.connect(self.db_path) as db: + async with db.execute( + "SELECT batch_id FROM batch_queue WHERE status = ? ORDER BY created_at LIMIT 1", + (BatchStatus.PENDING,), + ) as cursor: + row = await cursor.fetchone() + if row is None: + # No pending batch; pause briefly + await asyncio.sleep(1) + continue + batch_id = row[0] + # Mark as processing + await db.execute( + "UPDATE batch_queue SET status = ? 
WHERE batch_id = ?", + ("running", batch_id), + ) + await db.commit() + try: + logger.info("Processing batch %s", batch_id) + # Simulate processing delay + await asyncio.sleep(1) + # Simulate generating output file via storage + result_content = ( + f"Processed batch {batch_id} at {datetime.datetime.utcnow()}" + ) + file_info = await self.storage.save_file( + file_name=f"{batch_id}_result.txt", + content=result_content.encode("utf-8"), + purpose="batch_output", + ) + completed_at = int(time.time()) + new_status = BatchStatus.COMPLETED + output_file_id = ( + file_info.id if hasattr(file_info, "id") else "output_id" + ) + except Exception as e: + logger.error("Failed processing batch %s: %s", batch_id, e) + new_status = BatchStatus.FAILED + completed_at = int(time.time()) + output_file_id = None + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + "UPDATE batch_queue SET status = ?, completed_at = ?, output_file_id = ? WHERE batch_id = ?", + (new_status, completed_at, output_file_id, batch_id), + ) + await db.commit() + # Short sleep to avoid tight polling loop if many jobs are queued + await asyncio.sleep(0.1) diff --git a/src/vllm_router/batch/processor.py b/src/vllm_router/batch/processor.py new file mode 100644 index 0000000..ce321da --- /dev/null +++ b/src/vllm_router/batch/processor.py @@ -0,0 +1,45 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +from vllm_router.batch.batch import BatchInfo +from vllm_router.files import Storage + + +class BatchProcessor(ABC): + """Abstract base class for batch request processing""" + + def __init__(self, storage: Storage): + self.storage = storage + + @abstractmethod + async def initialize(self): + """Initialize the batch processor""" + pass + + @abstractmethod + async def create_batch( + self, + input_file_id: str, + endpoint: str, + completion_window: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> BatchInfo: + """Create a new batch job""" + pass + + @abstractmethod + async def retrieve_batch(self, batch_id: str) -> BatchInfo: + """Retrieve a specific batch job""" + pass + + @abstractmethod + async def list_batches( + self, limit: int = 100, after: str = None + ) -> List[BatchInfo]: + """List all batch jobs with pagination""" + pass + + @abstractmethod + async def cancel_batch(self, batch_id: str) -> BatchInfo: + """Cancel a running batch job""" + pass diff --git a/src/vllm_router/files/file_storage.py b/src/vllm_router/files/file_storage.py index aa163b0..d8d7ce2 100644 --- a/src/vllm_router/files/file_storage.py +++ b/src/vllm_router/files/file_storage.py @@ -23,7 +23,7 @@ class FileStorage(Storage): def __init__(self, base_path: str = "/tmp/vllm_files"): self.base_path = base_path - logger.info(f"Using local file storage at {base_path}") + logger.info("Initialize FileStorage with base path %s", base_path) os.makedirs(base_path, exist_ok=True) def _get_user_path(self, user_id: str) -> str: diff --git a/src/vllm_router/router.py b/src/vllm_router/router.py index 636a60d..19216fc 100644 --- a/src/vllm_router/router.py +++ b/src/vllm_router/router.py @@ -9,8 +9,9 @@ from fastapi import FastAPI, Request, UploadFile from fastapi.responses import JSONResponse, Response, StreamingResponse +from vllm_router.batch import BatchProcessor, initialize_batch_processor from vllm_router.engine_stats import GetEngineStatsScraper, InitializeEngineStatsScraper -from vllm_router.files import initialize_storage +from vllm_router.files import Storage, initialize_storage from 
vllm_router.httpx_client import HTTPXClientWrapper from vllm_router.protocols import ModelCard, ModelList from vllm_router.request_stats import ( @@ -34,6 +35,8 @@ @asynccontextmanager async def lifespan(app: FastAPI): httpx_client_wrapper.start() + if hasattr(app.state, "batch_processor"): + await app.state.batch_processor.initialize() yield await httpx_client_wrapper.stop() @@ -157,7 +160,8 @@ async def route_files(request: Request): file_content = await file_obj.read() try: - file_info = await FILE_STORAGE.save_file( + storage: Storage = app.state.batch_storage + file_info = await storage.save_file( file_name=file_obj.filename, content=file_content, purpose=purpose ) return JSONResponse(content=file_info.metadata()) @@ -170,7 +174,8 @@ async def route_files(request: Request): @app.get("/v1/files/{file_id}") async def route_get_file(file_id: str): try: - file = await FILE_STORAGE.get_file(file_id) + storage: Storage = app.state.batch_storage + file = await storage.get_file(file_id) return JSONResponse(content=file.metadata()) except FileNotFoundError: return JSONResponse( @@ -183,7 +188,8 @@ async def route_get_file_content(file_id: str): try: # TODO(gaocegege): Stream the file content with chunks to support # openai uploads interface. - file_content = await FILE_STORAGE.get_file_content(file_id) + storage: Storage = app.state.batch_storage + file_content = await storage.get_file_content(file_id) return Response(content=file_content) except FileNotFoundError: return JSONResponse( @@ -191,6 +197,99 @@ async def route_get_file_content(file_id: str): ) +@app.post("/v1/batches") +async def route_batches(request: Request): + """Handle batch requests that process files with specified endpoints.""" + try: + request_json = await request.json() + + # Validate required fields + if "input_file_id" not in request_json: + return JSONResponse( + status_code=400, + content={"error": "Missing required parameter 'input_file_id'"}, + ) + if "endpoint" not in request_json: + return JSONResponse( + status_code=400, + content={"error": "Missing required parameter 'endpoint'"}, + ) + + # Verify file exists + storage: Storage = app.state.batch_storage + file_id = request_json["input_file_id"] + try: + await storage.get_file(file_id) + except FileNotFoundError: + return JSONResponse( + status_code=404, content={"error": f"File {file_id} not found"} + ) + + batch_processor: BatchProcessor = app.state.batch_processor + batch = await batch_processor.create_batch( + input_file_id=file_id, + endpoint=request_json["endpoint"], + completion_window=request_json.get("completion_window", "5s"), + metadata=request_json.get("metadata", None), + ) + + # Return metadata as attribute, not a callable. 
+ return JSONResponse(content=batch.to_dict()) + + except Exception as e: + return JSONResponse( + status_code=500, + content={"error": f"Failed to process batch request: {str(e)}"}, + ) + + +@app.get("/v1/batches/{batch_id}") +async def route_get_batch(batch_id: str): + try: + batch_processor: BatchProcessor = app.state.batch_processor + batch = await batch_processor.retrieve_batch(batch_id) + return JSONResponse(content=batch.to_dict()) + except FileNotFoundError: + return JSONResponse( + status_code=404, content={"error": f"Batch {batch_id} not found"} + ) + + +@app.get("/v1/batches") +async def route_list_batches(limit: int = 20, after: str = None): + try: + batch_processor: BatchProcessor = app.state.batch_processor + batches = await batch_processor.list_batches(limit=limit, after=after) + + # Convert batches to response format + batch_data = [batch.to_dict() for batch in batches] + + response = { + "object": "list", + "data": batch_data, + "first_id": batch_data[0]["id"] if batch_data else None, + "last_id": batch_data[-1]["id"] if batch_data else None, + "has_more": len(batch_data) + == limit, # If we got limit items, there may be more + } + + return JSONResponse(content=response) + except FileNotFoundError: + return JSONResponse(status_code=404, content={"error": "No batches found"}) + + +@app.delete("/v1/batches/{batch_id}") +async def route_cancel_batch(batch_id: str): + try: + batch_processor: BatchProcessor = app.state.batch_processor + batch = await batch_processor.cancel_batch(batch_id) + return JSONResponse(content=batch.to_dict()) + except FileNotFoundError: + return JSONResponse( + status_code=404, content={"error": f"Batch {batch_id} not found"} + ) + + @app.post("/v1/chat/completions") async def route_chat_completition(request: Request): return await route_general_request(request, "/v1/chat/completions") @@ -341,6 +440,11 @@ def parse_args(): # Batch API # TODO(gaocegege): Make these batch api related arguments to a separate config. + parser.add_argument( + "--enable-batch-api", + action="store_true", + help="Enable the batch API for processing files.", + ) parser.add_argument( "--file-storage-class", type=str, @@ -354,6 +458,13 @@ def parse_args(): default="/tmp/vllm_files", help="The path to store files.", ) + parser.add_argument( + "--batch-processor", + type=str, + default="local", + choices=["local"], + help="The batch processor to use.", + ) # Monitoring parser.add_argument( @@ -421,10 +532,14 @@ def InitializeAll(args): InitializeEngineStatsScraper(args.engine_stats_interval) InitializeRequestStatsMonitor(args.request_stats_window) - # TODO(gaocegege): Try adopting a more general way to initialize the - # storage, and global router. Maybe singleton? - global FILE_STORAGE - FILE_STORAGE = initialize_storage(args.file_storage_class, args.file_storage_path) + if args.enable_batch_api: + logger.info("Initializing batch API") + app.state.batch_storage = initialize_storage( + args.file_storage_class, args.file_storage_path + ) + app.state.batch_processor = initialize_batch_processor( + args.batch_processor, args.file_storage_path, app.state.batch_storage + ) InitializeRoutingLogic(args.routing_logic, session_key=args.session_key)
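
Usage sketch (illustrative, not taken from the patch): the diff above exposes the batch workflow over HTTP behind the new --enable-batch-api flag. The following minimal Python sketch exercises that surface with plain requests. It assumes the router is running locally on port 8000 with --enable-batch-api, that the multipart field names accepted by POST /v1/files are "file" and "purpose", that the returned file metadata contains an "id" field, and that file content is served at /v1/files/{id}/content (mirroring the OpenAI layout; that route's exact path is not visible in this hunk).

import time

import requests

BASE = "http://localhost:8000"  # assumed router address

# 1. Upload a JSONL input file (multipart form; field names are assumptions).
with open("batch.jsonl", "rb") as f:
    file_meta = requests.post(
        f"{BASE}/v1/files",
        files={"file": ("batch.jsonl", f)},
        data={"purpose": "batch"},
    ).json()

# 2. Create a batch job pointing at the uploaded file.
batch = requests.post(
    f"{BASE}/v1/batches",
    json={
        "input_file_id": file_meta["id"],
        "endpoint": "/v1/chat/completions",
        "completion_window": "1h",
    },
).json()

# 3. Poll until the processor reports a terminal status
#    (BatchStatus is pending/running/completed/failed; cancel_batch sets "cancelled").
while batch["status"] not in ("completed", "failed", "cancelled"):
    time.sleep(1)
    batch = requests.get(f"{BASE}/v1/batches/{batch['id']}").json()

# 4. Fetch the output file written by LocalBatchProcessor (content path assumed).
if batch["status"] == "completed":
    out = requests.get(f"{BASE}/v1/files/{batch['output_file_id']}/content")
    print(out.text)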
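
The LocalBatchProcessor can also be driven directly, without the router, which is useful for poking at the SQLite-backed queue in isolation. The sketch below follows the call pattern used in router.py (construct storage and processor, then await initialize()); the dummy JSONL payload, the /tmp paths, and the assumption that save_file() returns metadata exposing an .id attribute are illustrative only, not guaranteed by the patch.

import asyncio

from vllm_router.batch import BatchStatus, initialize_batch_processor
from vllm_router.files.file_storage import FileStorage


async def main() -> None:
    storage = FileStorage("/tmp/vllm_files")
    processor = initialize_batch_processor("local", "/tmp/vllm_files", storage)
    await processor.initialize()  # sets up the SQLite DB and starts process_batches()

    # Store a dummy JSONL input; .id on the returned metadata is an assumption.
    file_info = await storage.save_file(
        file_name="batch.jsonl",
        content=b'{"prompt": "hello"}\n',
        purpose="batch",
    )

    batch = await processor.create_batch(
        input_file_id=file_info.id,
        endpoint="/v1/chat/completions",
        completion_window="1h",
    )

    # The background task moves the job from pending to running to completed/failed.
    info = await processor.retrieve_batch(batch.id)
    while info.status not in (BatchStatus.COMPLETED, BatchStatus.FAILED):
        await asyncio.sleep(0.5)
        info = await processor.retrieve_batch(batch.id)
    print(info.to_dict())


asyncio.run(main())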