From 0e13eac4d60956642a2ab39ee56c2a35e1bb77a4 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Thu, 27 Jun 2024 12:34:25 -0700 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Prompt=20adapter=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- src/vllm_tgis_adapter/__main__.py | 12 +++++++++- src/vllm_tgis_adapter/grpc/adapters.py | 27 ++++++++++++++++++----- src/vllm_tgis_adapter/grpc/grpc_server.py | 9 +++++--- src/vllm_tgis_adapter/tgis_utils/args.py | 4 ++++ 4 files changed, 43 insertions(+), 9 deletions(-) diff --git a/src/vllm_tgis_adapter/__main__.py b/src/vllm_tgis_adapter/__main__.py index cf546954..72fd7e3c 100644 --- a/src/vllm_tgis_adapter/__main__.py +++ b/src/vllm_tgis_adapter/__main__.py @@ -202,8 +202,18 @@ async def run_http_server( args.lora_modules, args.chat_template, ) + + kwargs = {} + # prompt adapter arg required for vllm >0.5.1 + if hasattr(args, "prompt_adapters"): + kwargs = {"prompt_adapters": args.prompt_adapters} + openai_serving_completion = OpenAIServingCompletion( - engine, model_config, served_model_names, args.lora_modules + engine, + model_config, + served_model_names, + args.lora_modules, + **kwargs, ) openai_serving_embedding = OpenAIServingEmbedding( engine, model_config, served_model_names diff --git a/src/vllm_tgis_adapter/grpc/adapters.py b/src/vllm_tgis_adapter/grpc/adapters.py index 321968ce..e3bd5178 100644 --- a/src/vllm_tgis_adapter/grpc/adapters.py +++ b/src/vllm_tgis_adapter/grpc/adapters.py @@ -20,6 +20,7 @@ SingleGenerationRequest, ) from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from .validation import TGISValidationError @@ -33,6 +34,7 @@ class AdapterMetadata: unique_id: int # Unique integer for vllm to identify the adapter adapter_type: str # The string name of the peft adapter type, e.g. LORA full_path: str + full_config: dict # The loaded adapter_config.json dict @dataclasses.dataclass @@ -45,7 +47,7 @@ class AdapterStore: async def validate_adapters( request: SingleGenerationRequest | BatchedGenerationRequest, adapter_store: AdapterStore | None, -) -> dict[str, LoRARequest]: +) -> dict[str, LoRARequest | PromptAdapterRequest]: """Validate the adapters. Takes the adapter name from the request and constructs a valid @@ -56,6 +58,9 @@ async def validate_adapters( """ global global_thread_pool # noqa: PLW0603 adapter_id = request.adapter_id + # Backwards compatibility for `prefix_id` arg + if not adapter_id and request.prefix_id: + adapter_id = request.prefix_id if adapter_id and not adapter_store: TGISValidationError.AdaptersDisabled.error() @@ -73,18 +78,20 @@ async def validate_adapters( if global_thread_pool is None: global_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2) - adapter_type = await loop.run_in_executor( + adapter_config = await loop.run_in_executor( global_thread_pool, - _get_adapter_type_from_file, + _load_adapter_config_from_file, adapter_id, local_adapter_path, ) + adapter_type = adapter_config.get("peft_type", None) # Add to cache adapter_metadata = AdapterMetadata( unique_id=adapter_store.next_unique_id, adapter_type=adapter_type, full_path=local_adapter_path, + full_config=adapter_config, ) adapter_store.adapters[adapter_id] = adapter_metadata @@ -96,12 +103,22 @@ async def validate_adapters( lora_local_path=adapter_metadata.full_path, ) return {"lora_request": lora_request} + if adapter_metadata.adapter_type == "PROMPT_TUNING": + prompt_adapter_request = PromptAdapterRequest( + prompt_adapter_id=adapter_metadata.unique_id, + prompt_adapter_name=adapter_id, + prompt_adapter_local_path=adapter_metadata.full_path, + prompt_adapter_num_virtual_tokens=adapter_metadata.full_config.get( + "num_virtual_tokens", 0 + ), + ) + return {"prompt_adapter_request": prompt_adapter_request} # All other types unsupported TGISValidationError.AdapterUnsupported.error(adapter_metadata.adapter_type) # noqa: RET503 -def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str: +def _load_adapter_config_from_file(adapter_id: str, adapter_path: str) -> dict: """Get adapter from file. Performs all the filesystem access required to deduce the type @@ -123,7 +140,7 @@ def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str: with open(adapter_config_path) as adapter_config_file: adapter_config = json.load(adapter_config_file) - return adapter_config.get("peft_type", None) + return adapter_config def _reject_bad_adapter_id(adapter_id: str) -> None: diff --git a/src/vllm_tgis_adapter/grpc/grpc_server.py b/src/vllm_tgis_adapter/grpc/grpc_server.py index 576c4a76..69be1c4d 100644 --- a/src/vllm_tgis_adapter/grpc/grpc_server.py +++ b/src/vllm_tgis_adapter/grpc/grpc_server.py @@ -20,6 +20,7 @@ from vllm.engine.async_llm_engine import _AsyncLLMEngine from vllm.entrypoints.openai.serving_completion import merge_async_iterators from vllm.inputs import TextTokensPrompt +from vllm.prompt_adapter.request import PromptAdapterRequest # noqa: TCH002 from vllm_tgis_adapter.logging import init_logger from vllm_tgis_adapter.tgis_utils import logs @@ -177,9 +178,11 @@ def __init__( self.skip_special_tokens = not args.output_special_tokens self.default_include_stop_seqs = args.default_include_stop_seqs + # Backwards compatibility for TGIS: PREFIX_STORE_PATH + adapter_cache_path = args.adapter_cache or args.prefix_store_path self.adapter_store = ( - AdapterStore(cache_path=args.adapter_cache, adapters={}) - if args.adapter_cache + AdapterStore(cache_path=adapter_cache_path, adapters={}) + if adapter_cache_path else None ) self.health_servicer = health_servicer @@ -585,7 +588,7 @@ async def _validate_adapters( self, request: SingleGenerationRequest | BatchedGenerationRequest, context: ServicerContext, - ) -> dict[str, LoRARequest]: + ) -> dict[str, LoRARequest | PromptAdapterRequest]: try: adapters = await validate_adapters( request=request, adapter_store=self.adapter_store diff --git a/src/vllm_tgis_adapter/tgis_utils/args.py b/src/vllm_tgis_adapter/tgis_utils/args.py index 9522ac34..d0b79a86 100644 --- a/src/vllm_tgis_adapter/tgis_utils/args.py +++ b/src/vllm_tgis_adapter/tgis_utils/args.py @@ -116,6 +116,10 @@ def add_tgis_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument("--tls-client-ca-cert-path", type=str) # add a path when peft adapters will be loaded from parser.add_argument("--adapter-cache", type=str) + # backwards-compatibility support for tgis prompt tuning + parser.add_argument( + "--prefix-store-path", type=str, help="Deprecated, use --adapter-cache" + ) # TODO check/add other args here