From 5f983859ab4c6eb0b4eb5a42aad892fb07688e38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniele=20Trifir=C3=B2?= Date: Wed, 10 Jul 2024 18:58:33 +0200 Subject: [PATCH 1/2] deps: bump vllm to vllm>=0.5.1 --- .github/workflows/tests.yaml | 2 +- pyproject.toml | 2 +- src/vllm_tgis_adapter/grpc/grpc_server.py | 51 +++++++-------------- src/vllm_tgis_adapter/tgis_utils/metrics.py | 8 +--- 4 files changed, 19 insertions(+), 44 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 9503ea2..98978e8 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -32,7 +32,7 @@ jobs: pyv: ["3.11"] vllm_version: # - "" # skip the pypi version as it will not work on CPU - - "git+https://github.com/vllm-project/vllm@v0.5.0.post1" + - "git+https://github.com/vllm-project/vllm@v0.5.1" - "git+https://github.com/vllm-project/vllm@main" - "git+https://github.com/opendatahub-io/vllm@main" diff --git a/pyproject.toml b/pyproject.toml index ad08da1..e06c59b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ requires-python = ">=3.9" dynamic = ["version"] dependencies = [ - "vllm>=0.5.0", + "vllm>=0.5.1", "prometheus_client==0.20.0", "grpcio==1.62.2", "grpcio-health-checking==1.62.2", diff --git a/src/vllm_tgis_adapter/grpc/grpc_server.py b/src/vllm_tgis_adapter/grpc/grpc_server.py index 5eddd36..bc5366c 100644 --- a/src/vllm_tgis_adapter/grpc/grpc_server.py +++ b/src/vllm_tgis_adapter/grpc/grpc_server.py @@ -20,6 +20,11 @@ from vllm.engine.async_llm_engine import _AsyncLLMEngine from vllm.entrypoints.openai.serving_completion import merge_async_iterators from vllm.inputs import TextTokensPrompt +from vllm.tracing import ( + contains_trace_headers, + extract_trace_headers, + log_tracing_disabled_warning, +) from vllm_tgis_adapter.logging import init_logger from vllm_tgis_adapter.tgis_utils import logs @@ -51,18 +56,6 @@ ) from .validation import validate_input, validate_params -try: - from vllm.tracing import ( - contains_trace_headers, - extract_trace_headers, - log_tracing_disabled_warning, - ) -except ImportError: - _vllm_tracing_available = False -else: - _vllm_tracing_available = True - - if TYPE_CHECKING: import argparse from collections.abc import AsyncIterator, MutableSequence @@ -191,22 +184,11 @@ async def post_init(self) -> None: assert self.tokenizer is not None # Swap in the special TGIS stats logger - if hasattr(self.engine.engine, "stat_logger"): - # vllm <=0.5.1 - tgis_stats_logger = TGISStatLogger( - vllm_stat_logger=self.engine.engine.stat_logger, - max_sequence_len=self.config.max_model_len, - ) - self.engine.engine.stat_logger = tgis_stats_logger - elif hasattr(self.engine.engine, "stat_loggers"): - # vllm>=0.5.2 - tgis_stats_logger = TGISStatLogger( - vllm_stat_logger=self.engine.engine.stat_loggers["prometheus"], - max_sequence_len=self.config.max_model_len, - ) - self.engine.engine.stat_loggers["prometheus"] = tgis_stats_logger - else: - raise ValueError("engine doesn't have any known loggers.") + tgis_stats_logger = TGISStatLogger( + vllm_stat_logger=self.engine.engine.stat_loggers["prometheus"], + max_sequence_len=self.config.max_model_len, + ) + self.engine.engine.stat_loggers["prometheus"] = tgis_stats_logger self.health_servicer.set( self.SERVICE_NAME, @@ -243,13 +225,12 @@ async def Generate( prompt_token_ids=input_ids, ) kwargs = {} - if _vllm_tracing_available: - is_tracing_enabled = await self.engine.is_tracing_enabled() - headers = dict(context.invocation_metadata()) - if is_tracing_enabled: - kwargs["trace_headers"] = extract_trace_headers(headers) - elif contains_trace_headers(headers): - log_tracing_disabled_warning() + is_tracing_enabled = await self.engine.is_tracing_enabled() + headers = dict(context.invocation_metadata()) + if is_tracing_enabled: + kwargs["trace_headers"] = extract_trace_headers(headers) + elif contains_trace_headers(headers): + log_tracing_disabled_warning() generators.append( self.engine.generate( inputs=inputs, diff --git a/src/vllm_tgis_adapter/tgis_utils/metrics.py b/src/vllm_tgis_adapter/tgis_utils/metrics.py index d7537ac..979403f 100644 --- a/src/vllm_tgis_adapter/tgis_utils/metrics.py +++ b/src/vllm_tgis_adapter/tgis_utils/metrics.py @@ -5,13 +5,7 @@ from prometheus_client import Counter, Gauge, Histogram from vllm import RequestOutput -from vllm.engine.metrics import Stats - -try: - from vllm.engine.metrics import StatLoggerBase -except ImportError: - # vllm<=0.5.1 - from vllm.engine.metrics import StatLogger as StatLoggerBase +from vllm.engine.metrics import StatLoggerBase, Stats from vllm_tgis_adapter.grpc.pb.generation_pb2 import ( BatchedTokenizeRequest, From b232f9e233e71f0baf627e6ab8f0d7e5a7b23f75 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Thu, 27 Jun 2024 12:34:25 -0700 Subject: [PATCH 2/2] =?UTF-8?q?=E2=9C=A8=20Prompt=20adapter=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Prashant Gupta --- src/vllm_tgis_adapter/__main__.py | 12 +++++++- src/vllm_tgis_adapter/grpc/adapters.py | 34 +++++++++++++++++------ src/vllm_tgis_adapter/grpc/grpc_server.py | 34 +++++++++++++++++++---- src/vllm_tgis_adapter/tgis_utils/args.py | 4 +++ 4 files changed, 69 insertions(+), 15 deletions(-) diff --git a/src/vllm_tgis_adapter/__main__.py b/src/vllm_tgis_adapter/__main__.py index cf54695..72fd7e3 100644 --- a/src/vllm_tgis_adapter/__main__.py +++ b/src/vllm_tgis_adapter/__main__.py @@ -202,8 +202,18 @@ async def run_http_server( args.lora_modules, args.chat_template, ) + + kwargs = {} + # prompt adapter arg required for vllm >0.5.1 + if hasattr(args, "prompt_adapters"): + kwargs = {"prompt_adapters": args.prompt_adapters} + openai_serving_completion = OpenAIServingCompletion( - engine, model_config, served_model_names, args.lora_modules + engine, + model_config, + served_model_names, + args.lora_modules, + **kwargs, ) openai_serving_embedding = OpenAIServingEmbedding( engine, model_config, served_model_names diff --git a/src/vllm_tgis_adapter/grpc/adapters.py b/src/vllm_tgis_adapter/grpc/adapters.py index 321968c..e56def4 100644 --- a/src/vllm_tgis_adapter/grpc/adapters.py +++ b/src/vllm_tgis_adapter/grpc/adapters.py @@ -14,14 +14,16 @@ from pathlib import Path from typing import TYPE_CHECKING +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest + +from .validation import TGISValidationError + if TYPE_CHECKING: from vllm.entrypoints.grpc.pb.generation_pb2 import ( BatchedGenerationRequest, SingleGenerationRequest, ) -from vllm.lora.request import LoRARequest - -from .validation import TGISValidationError global_thread_pool = None # used for loading adapter files from disk @@ -33,6 +35,7 @@ class AdapterMetadata: unique_id: int # Unique integer for vllm to identify the adapter adapter_type: str # The string name of the peft adapter type, e.g. LORA full_path: str + full_config: dict # The loaded adapter_config.json dict @dataclasses.dataclass @@ -45,7 +48,7 @@ class AdapterStore: async def validate_adapters( request: SingleGenerationRequest | BatchedGenerationRequest, adapter_store: AdapterStore | None, -) -> dict[str, LoRARequest]: +) -> dict[str, LoRARequest | PromptAdapterRequest]: """Validate the adapters. Takes the adapter name from the request and constructs a valid @@ -56,6 +59,9 @@ async def validate_adapters( """ global global_thread_pool # noqa: PLW0603 adapter_id = request.adapter_id + # Backwards compatibility for `prefix_id` arg + if not adapter_id and request.prefix_id: + adapter_id = request.prefix_id if adapter_id and not adapter_store: TGISValidationError.AdaptersDisabled.error() @@ -73,18 +79,20 @@ async def validate_adapters( if global_thread_pool is None: global_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2) - adapter_type = await loop.run_in_executor( + adapter_config = await loop.run_in_executor( global_thread_pool, - _get_adapter_type_from_file, + _load_adapter_config_from_file, adapter_id, local_adapter_path, ) + adapter_type = adapter_config.get("peft_type", None) # Add to cache adapter_metadata = AdapterMetadata( unique_id=adapter_store.next_unique_id, adapter_type=adapter_type, full_path=local_adapter_path, + full_config=adapter_config, ) adapter_store.adapters[adapter_id] = adapter_metadata @@ -96,12 +104,22 @@ async def validate_adapters( lora_local_path=adapter_metadata.full_path, ) return {"lora_request": lora_request} + if adapter_metadata.adapter_type == "PROMPT_TUNING": + prompt_adapter_request = PromptAdapterRequest( + prompt_adapter_id=adapter_metadata.unique_id, + prompt_adapter_name=adapter_id, + prompt_adapter_local_path=adapter_metadata.full_path, + prompt_adapter_num_virtual_tokens=adapter_metadata.full_config.get( + "num_virtual_tokens", 0 + ), + ) + return {"prompt_adapter_request": prompt_adapter_request} # All other types unsupported TGISValidationError.AdapterUnsupported.error(adapter_metadata.adapter_type) # noqa: RET503 -def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str: +def _load_adapter_config_from_file(adapter_id: str, adapter_path: str) -> dict: """Get adapter from file. Performs all the filesystem access required to deduce the type @@ -123,7 +141,7 @@ def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str: with open(adapter_config_path) as adapter_config_file: adapter_config = json.load(adapter_config_file) - return adapter_config.get("peft_type", None) + return adapter_config def _reject_bad_adapter_id(adapter_id: str) -> None: diff --git a/src/vllm_tgis_adapter/grpc/grpc_server.py b/src/vllm_tgis_adapter/grpc/grpc_server.py index bc5366c..af309b9 100644 --- a/src/vllm_tgis_adapter/grpc/grpc_server.py +++ b/src/vllm_tgis_adapter/grpc/grpc_server.py @@ -41,7 +41,6 @@ TGISStatLogger, ) -from .adapters import AdapterStore, validate_adapters from .pb import generation_pb2_grpc from .pb.generation_pb2 import DESCRIPTOR as _GENERATION_DESCRIPTOR from .pb.generation_pb2 import ( @@ -56,6 +55,14 @@ ) from .validation import validate_input, validate_params +try: + from .adapters import AdapterStore, validate_adapters +except ImportError: + adapters_available = False +else: + adapters_available = True + + if TYPE_CHECKING: import argparse from collections.abc import AsyncIterator, MutableSequence @@ -76,6 +83,11 @@ SingleGenerationRequest, ) + try: + from .adapters import PromptAdapterRequest + except ImportError: + pass + _T = TypeVar("_T") _F = TypeVar("_F", Callable, Coroutine) @@ -170,9 +182,11 @@ def __init__( self.skip_special_tokens = not args.output_special_tokens self.default_include_stop_seqs = args.default_include_stop_seqs + # Backwards compatibility for TGIS: PREFIX_STORE_PATH + adapter_cache_path = args.adapter_cache or args.prefix_store_path self.adapter_store = ( - AdapterStore(cache_path=args.adapter_cache, adapters={}) - if args.adapter_cache + AdapterStore(cache_path=adapter_cache_path, adapters={}) + if adapter_cache_path else None ) self.health_servicer = health_servicer @@ -213,7 +227,11 @@ async def Generate( generators = [] max_is_token_limit = [False] * request_count - adapter_kwargs = await self._validate_adapters(request, context) + adapter_kwargs = ( + await self._validate_adapters(request, context) + if adapters_available + else {} + ) for i, req in enumerate(request.requests): input_ids, max_is_token_limit[i] = await self._validate_prompt_and_tokenize( @@ -309,7 +327,11 @@ async def GenerateStream( sampling_params, truncate_input_tokens, request.request.text, context ) - adapter_kwargs = await self._validate_adapters(request, context) + adapter_kwargs = ( + await self._validate_adapters(request, context) + if adapters_available + else {} + ) inputs = TextTokensPrompt( prompt=request.request.text, prompt_token_ids=input_ids ) @@ -577,7 +599,7 @@ async def _validate_adapters( self, request: SingleGenerationRequest | BatchedGenerationRequest, context: ServicerContext, - ) -> dict[str, LoRARequest]: + ) -> dict[str, LoRARequest | PromptAdapterRequest]: try: adapters = await validate_adapters( request=request, adapter_store=self.adapter_store diff --git a/src/vllm_tgis_adapter/tgis_utils/args.py b/src/vllm_tgis_adapter/tgis_utils/args.py index 9522ac3..d0b79a8 100644 --- a/src/vllm_tgis_adapter/tgis_utils/args.py +++ b/src/vllm_tgis_adapter/tgis_utils/args.py @@ -116,6 +116,10 @@ def add_tgis_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument("--tls-client-ca-cert-path", type=str) # add a path when peft adapters will be loaded from parser.add_argument("--adapter-cache", type=str) + # backwards-compatibility support for tgis prompt tuning + parser.add_argument( + "--prefix-store-path", type=str, help="Deprecated, use --adapter-cache" + ) # TODO check/add other args here