add support for prompt adapters #21

Merged · 2 commits · Jul 11, 2024
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
@@ -32,7 +32,7 @@ jobs:
         pyv: ["3.11"]
         vllm_version:
           # - "" # skip the pypi version as it will not work on CPU
-          - "git+https://github.com/vllm-project/[email protected]"
+          - "git+https://github.com/vllm-project/[email protected]"
           - "git+https://github.com/vllm-project/vllm@main"
           - "git+https://github.com/opendatahub-io/vllm@main"

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -26,7 +26,7 @@ classifiers = [
 requires-python = ">=3.9"
 dynamic = ["version"]
 dependencies = [
-    "vllm>=0.5.0",
+    "vllm>=0.5.1",
     "prometheus_client==0.20.0",
     "grpcio==1.62.2",
     "grpcio-health-checking==1.62.2",

12 changes: 11 additions & 1 deletion src/vllm_tgis_adapter/__main__.py
@@ -202,8 +202,18 @@ async def run_http_server(
         args.lora_modules,
         args.chat_template,
     )
+
+    kwargs = {}
+    # prompt adapter arg required for vllm >0.5.1
+    if hasattr(args, "prompt_adapters"):
+        kwargs = {"prompt_adapters": args.prompt_adapters}
+
     openai_serving_completion = OpenAIServingCompletion(
-        engine, model_config, served_model_names, args.lora_modules
+        engine,
+        model_config,
+        served_model_names,
+        args.lora_modules,
+        **kwargs,
     )
     openai_serving_embedding = OpenAIServingEmbedding(
         engine, model_config, served_model_names
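
The hasattr check above is a small feature-detection shim. A standalone sketch of the same pattern, using only the standard library (the prompt_adapters attribute name comes from the diff; everything else here is illustrative):

import argparse

def build_compat_kwargs(args: argparse.Namespace) -> dict:
    """Collect kwargs that only newer vLLM versions understand."""
    kwargs = {}
    # Older vLLM versions never define this arg, so probe before passing it.
    if hasattr(args, "prompt_adapters"):
        kwargs["prompt_adapters"] = args.prompt_adapters
    return kwargs

# Namespaces with and without the new arg both work:
assert build_compat_kwargs(argparse.Namespace(prompt_adapters=None)) == {"prompt_adapters": None}
assert build_compat_kwargs(argparse.Namespace()) == {}

This keeps a single OpenAIServingCompletion call site valid across vLLM versions instead of branching on a version string.
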
34 changes: 26 additions & 8 deletions src/vllm_tgis_adapter/grpc/adapters.py
@@ -14,14 +14,16 @@
 from pathlib import Path
 from typing import TYPE_CHECKING

+from vllm.lora.request import LoRARequest
+from vllm.prompt_adapter.request import PromptAdapterRequest
+
+from .validation import TGISValidationError
+
 if TYPE_CHECKING:
     from vllm.entrypoints.grpc.pb.generation_pb2 import (
         BatchedGenerationRequest,
         SingleGenerationRequest,
     )
-    from vllm.lora.request import LoRARequest
-
-from .validation import TGISValidationError

 global_thread_pool = None  # used for loading adapter files from disk

@@ -33,6 +35,7 @@ class AdapterMetadata:
     unique_id: int  # Unique integer for vllm to identify the adapter
     adapter_type: str  # The string name of the peft adapter type, e.g. LORA
     full_path: str
+    full_config: dict  # The loaded adapter_config.json dict


 @dataclasses.dataclass
@@ -45,7 +48,7 @@ class AdapterStore:
 async def validate_adapters(
     request: SingleGenerationRequest | BatchedGenerationRequest,
     adapter_store: AdapterStore | None,
-) -> dict[str, LoRARequest]:
+) -> dict[str, LoRARequest | PromptAdapterRequest]:
     """Validate the adapters.

     Takes the adapter name from the request and constructs a valid
@@ -56,6 +59,9 @@ async def validate_adapters(
     """
     global global_thread_pool  # noqa: PLW0603
     adapter_id = request.adapter_id
+    # Backwards compatibility for `prefix_id` arg
+    if not adapter_id and request.prefix_id:
+        adapter_id = request.prefix_id

     if adapter_id and not adapter_store:
         TGISValidationError.AdaptersDisabled.error()
@@ -73,18 +79,20 @@ async def validate_adapters(
     if global_thread_pool is None:
         global_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)

-    adapter_type = await loop.run_in_executor(
+    adapter_config = await loop.run_in_executor(
         global_thread_pool,
-        _get_adapter_type_from_file,
+        _load_adapter_config_from_file,
         adapter_id,
         local_adapter_path,
     )
+    adapter_type = adapter_config.get("peft_type", None)

     # Add to cache
     adapter_metadata = AdapterMetadata(
         unique_id=adapter_store.next_unique_id,
         adapter_type=adapter_type,
         full_path=local_adapter_path,
+        full_config=adapter_config,
     )
     adapter_store.adapters[adapter_id] = adapter_metadata

@@ -96,12 +104,22 @@ async def validate_adapters(
             lora_local_path=adapter_metadata.full_path,
         )
         return {"lora_request": lora_request}
+    if adapter_metadata.adapter_type == "PROMPT_TUNING":
+        prompt_adapter_request = PromptAdapterRequest(
+            prompt_adapter_id=adapter_metadata.unique_id,
+            prompt_adapter_name=adapter_id,
+            prompt_adapter_local_path=adapter_metadata.full_path,
+            prompt_adapter_num_virtual_tokens=adapter_metadata.full_config.get(
+                "num_virtual_tokens", 0
+            ),
+        )
+        return {"prompt_adapter_request": prompt_adapter_request}

     # All other types unsupported
     TGISValidationError.AdapterUnsupported.error(adapter_metadata.adapter_type)  # noqa: RET503


-def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str:
+def _load_adapter_config_from_file(adapter_id: str, adapter_path: str) -> dict:
     """Get adapter from file.

     Performs all the filesystem access required to deduce the type
@@ -123,7 +141,7 @@ def _get_adapter_type_from_file(adapter_id: str, adapter_path: str) -> str:
     with open(adapter_config_path) as adapter_config_file:
         adapter_config = json.load(adapter_config_file)

-    return adapter_config.get("peft_type", None)
+    return adapter_config


 def _reject_bad_adapter_id(adapter_id: str) -> None:
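
For context, validate_adapters dispatches on the peft_type field of the adapter's adapter_config.json. A minimal sketch of a prompt-tuning adapter layout that would take the PROMPT_TUNING branch above (paths and field values are illustrative):

import json
from pathlib import Path

cache = Path("/tmp/adapter-cache")           # what --adapter-cache points at
adapter_dir = cache / "my-prompt-adapter"    # the adapter_id names this subdirectory
adapter_dir.mkdir(parents=True, exist_ok=True)
(adapter_dir / "adapter_config.json").write_text(
    json.dumps(
        {
            "peft_type": "PROMPT_TUNING",  # selects the PromptAdapterRequest branch
            "num_virtual_tokens": 8,  # copied into prompt_adapter_num_virtual_tokens
        }
    )
)

A request naming my-prompt-adapter (via adapter_id, or the legacy prefix_id) would then resolve to a PromptAdapterRequest; a config with "peft_type": "LORA" resolves to a LoRARequest, and any other type is rejected as unsupported.
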
75 changes: 39 additions & 36 deletions src/vllm_tgis_adapter/grpc/grpc_server.py
@@ -20,6 +20,11 @@
 from vllm.engine.async_llm_engine import _AsyncLLMEngine
 from vllm.entrypoints.openai.serving_completion import merge_async_iterators
 from vllm.inputs import TextTokensPrompt
+from vllm.tracing import (
+    contains_trace_headers,
+    extract_trace_headers,
+    log_tracing_disabled_warning,
+)

 from vllm_tgis_adapter.logging import init_logger
 from vllm_tgis_adapter.tgis_utils import logs
@@ -36,7 +41,6 @@
     TGISStatLogger,
 )

-from .adapters import AdapterStore, validate_adapters
 from .pb import generation_pb2_grpc
 from .pb.generation_pb2 import DESCRIPTOR as _GENERATION_DESCRIPTOR
 from .pb.generation_pb2 import (
@@ -52,15 +56,11 @@
 from .validation import validate_input, validate_params

 try:
-    from vllm.tracing import (
-        contains_trace_headers,
-        extract_trace_headers,
-        log_tracing_disabled_warning,
-    )
+    from .adapters import AdapterStore, validate_adapters
 except ImportError:
-    _vllm_tracing_available = False
+    adapters_available = False
 else:
-    _vllm_tracing_available = True
+    adapters_available = True


 if TYPE_CHECKING:
@@ -83,6 +83,11 @@
         SingleGenerationRequest,
     )

+try:
+    from .adapters import PromptAdapterRequest
+except ImportError:
+    pass
+

 _T = TypeVar("_T")
 _F = TypeVar("_F", Callable, Coroutine)
@@ -177,9 +182,11 @@ def __init__(
         self.skip_special_tokens = not args.output_special_tokens
         self.default_include_stop_seqs = args.default_include_stop_seqs

+        # Backwards compatibility for TGIS: PREFIX_STORE_PATH
+        adapter_cache_path = args.adapter_cache or args.prefix_store_path
         self.adapter_store = (
-            AdapterStore(cache_path=args.adapter_cache, adapters={})
-            if args.adapter_cache
+            AdapterStore(cache_path=adapter_cache_path, adapters={})
+            if adapter_cache_path
             else None
         )
         self.health_servicer = health_servicer
@@ -191,22 +198,11 @@ async def post_init(self) -> None:
         assert self.tokenizer is not None

         # Swap in the special TGIS stats logger
-        if hasattr(self.engine.engine, "stat_logger"):
-            # vllm <=0.5.1
-            tgis_stats_logger = TGISStatLogger(
-                vllm_stat_logger=self.engine.engine.stat_logger,
-                max_sequence_len=self.config.max_model_len,
-            )
-            self.engine.engine.stat_logger = tgis_stats_logger
-        elif hasattr(self.engine.engine, "stat_loggers"):
-            # vllm>=0.5.2
-            tgis_stats_logger = TGISStatLogger(
-                vllm_stat_logger=self.engine.engine.stat_loggers["prometheus"],
-                max_sequence_len=self.config.max_model_len,
-            )
-            self.engine.engine.stat_loggers["prometheus"] = tgis_stats_logger
-        else:
-            raise ValueError("engine doesn't have any known loggers.")
+        tgis_stats_logger = TGISStatLogger(
+            vllm_stat_logger=self.engine.engine.stat_loggers["prometheus"],
+            max_sequence_len=self.config.max_model_len,
+        )
+        self.engine.engine.stat_loggers["prometheus"] = tgis_stats_logger

         self.health_servicer.set(
             self.SERVICE_NAME,
@@ -231,7 +227,11 @@ async def Generate(
         generators = []
         max_is_token_limit = [False] * request_count

-        adapter_kwargs = await self._validate_adapters(request, context)
+        adapter_kwargs = (
+            await self._validate_adapters(request, context)
+            if adapters_available
+            else {}
+        )

         for i, req in enumerate(request.requests):
             input_ids, max_is_token_limit[i] = await self._validate_prompt_and_tokenize(
@@ -243,13 +243,12 @@ async def Generate(
                 prompt_token_ids=input_ids,
             )
             kwargs = {}
-            if _vllm_tracing_available:
-                is_tracing_enabled = await self.engine.is_tracing_enabled()
-                headers = dict(context.invocation_metadata())
-                if is_tracing_enabled:
-                    kwargs["trace_headers"] = extract_trace_headers(headers)
-                elif contains_trace_headers(headers):
-                    log_tracing_disabled_warning()
+            is_tracing_enabled = await self.engine.is_tracing_enabled()
+            headers = dict(context.invocation_metadata())
+            if is_tracing_enabled:
+                kwargs["trace_headers"] = extract_trace_headers(headers)
+            elif contains_trace_headers(headers):
+                log_tracing_disabled_warning()
             generators.append(
                 self.engine.generate(
                     inputs=inputs,
@@ -328,7 +327,11 @@ async def GenerateStream(
             sampling_params, truncate_input_tokens, request.request.text, context
         )

-        adapter_kwargs = await self._validate_adapters(request, context)
+        adapter_kwargs = (
+            await self._validate_adapters(request, context)
+            if adapters_available
+            else {}
+        )
         inputs = TextTokensPrompt(
             prompt=request.request.text, prompt_token_ids=input_ids
         )
@@ -596,7 +599,7 @@ async def _validate_adapters(
         self,
         request: SingleGenerationRequest | BatchedGenerationRequest,
         context: ServicerContext,
-    ) -> dict[str, LoRARequest]:
+    ) -> dict[str, LoRARequest | PromptAdapterRequest]:
         try:
             adapters = await validate_adapters(
                 request=request, adapter_store=self.adapter_store
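
The try/except around the adapters import doubles as a version probe: adapters.py imports vllm.prompt_adapter.request at module import time, so on a vLLM build without prompt-adapter support the import raises and every call site falls back to an empty kwargs dict. A minimal standalone sketch of that gate (the module path mirrors the diff; the helper function is illustrative):

try:
    from vllm.prompt_adapter.request import PromptAdapterRequest  # noqa: F401
except ImportError:  # vLLM too old to know about prompt adapters
    adapters_available = False
else:
    adapters_available = True

async def adapter_kwargs_or_empty(validate, request, adapter_store) -> dict:
    # Mirrors the call sites in Generate/GenerateStream: adapter validation
    # only runs when the optional dependency imported cleanly.
    return await validate(request, adapter_store) if adapters_available else {}
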
4 changes: 4 additions & 0 deletions src/vllm_tgis_adapter/tgis_utils/args.py
@@ -116,6 +116,10 @@ def add_tgis_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
     parser.add_argument("--tls-client-ca-cert-path", type=str)
     # add a path when peft adapters will be loaded from
     parser.add_argument("--adapter-cache", type=str)
+    # backwards-compatibility support for tgis prompt tuning
+    parser.add_argument(
+        "--prefix-store-path", type=str, help="Deprecated, use --adapter-cache"
+    )

     # TODO check/add other args here
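
Since grpc_server.py computes args.adapter_cache or args.prefix_store_path, the new flag wins whenever both are given. A small sketch of that precedence (paths are illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--adapter-cache", type=str)
parser.add_argument(
    "--prefix-store-path", type=str, help="Deprecated, use --adapter-cache"
)

# Legacy TGIS invocation: the deprecated flag still takes effect.
args = parser.parse_args(["--prefix-store-path", "/data/prefixes"])
assert (args.adapter_cache or args.prefix_store_path) == "/data/prefixes"

# When both are set, --adapter-cache is preferred.
args = parser.parse_args(
    ["--adapter-cache", "/data/adapters", "--prefix-store-path", "/data/prefixes"]
)
assert (args.adapter_cache or args.prefix_store_path) == "/data/adapters"
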
8 changes: 1 addition & 7 deletions src/vllm_tgis_adapter/tgis_utils/metrics.py
@@ -5,13 +5,7 @@

 from prometheus_client import Counter, Gauge, Histogram
 from vllm import RequestOutput
-from vllm.engine.metrics import Stats
-
-try:
-    from vllm.engine.metrics import StatLoggerBase
-except ImportError:
-    # vllm<=0.5.1
-    from vllm.engine.metrics import StatLogger as StatLoggerBase
+from vllm.engine.metrics import StatLoggerBase, Stats

 from vllm_tgis_adapter.grpc.pb.generation_pb2 import (
     BatchedTokenizeRequest,