From 0087a5abe3cde96c5d4ee0013a0370b0c4681549 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Fri, 24 Jan 2025 08:34:46 +0000 Subject: [PATCH 1/3] [V1][Metrics] Add simple prometheus logger Part of #10582 Implement the vllm:num_requests_running and vllm:num_requests_waiting gauges from V0. This is a simple starting point from which to iterate towards parity with V0. There's no need to use prometheus_client's "multi-processing mode" (at least at this stage) because these metrics all exist within the API server process. Note this restores the following metrics - these were lost when we started using multi-processing mode: - python_gc_objects_collected_total - python_gc_objects_uncollectable_total - python_gc_collections_total - python_info - process_virtual_memory_bytes - process_resident_memory_bytes - process_start_time_seconds - process_cpu_seconds_total - process_open_fds - process_max_fds Signed-off-by: Mark McLoughlin --- vllm/v1/engine/async_llm.py | 11 +++++++---- vllm/v1/metrics/loggers.py | 27 +++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 6dc68b3a16099..917d52d3220b8 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -24,7 +24,8 @@ from vllm.v1.engine.output_processor import OutputProcessor from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor -from vllm.v1.metrics.loggers import LoggingStatLogger, StatLoggerBase +from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger, + StatLoggerBase) from vllm.v1.metrics.stats import IterationStats, SchedulerStats logger = init_logger(__name__) @@ -46,13 +47,15 @@ def __init__( assert start_engine_loop + self.model_config = vllm_config.model_config + self.log_requests = log_requests self.log_stats = log_stats self.stat_loggers: List[StatLoggerBase] = [ LoggingStatLogger(), - # TODO(rob): PrometheusStatLogger(), + PrometheusStatLogger(labels=dict( + model_name=self.model_config.served_model_name)), ] - self.model_config = vllm_config.model_config # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( @@ -272,7 +275,7 @@ async def _run_output_handler(self): # 4) Logging. # TODO(rob): make into a coroutine and launch it in - # background thread once we add Prometheus. + # background thread once Prometheus overhead is non-trivial. 
assert iteration_stats is not None self._log_stats( scheduler_stats=outputs.scheduler_stats, diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 8feeef17542e6..486ab93965c18 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -1,5 +1,8 @@ import time from abc import ABC, abstractmethod +from typing import Dict + +import prometheus_client from vllm.logger import init_logger from vllm.v1.metrics.stats import SchedulerStats @@ -36,3 +39,27 @@ def log(self, scheduler_stats: SchedulerStats): scheduler_stats.num_running_reqs, scheduler_stats.num_waiting_reqs, ) + + +class PrometheusStatLogger(StatLoggerBase): + + def __init__(self, labels: Dict[str, str]): + self.labels = labels + + labelnames = self.labels.keys() + labelvalues = self.labels.values() + + self.gauge_scheduler_running = prometheus_client.Gauge( + name="vllm:num_requests_running", + documentation="Number of requests in model execution batches.", + labelnames=labelnames).labels(*labelvalues) + + self.gauge_scheduler_waiting = prometheus_client.Gauge( + name="vllm:num_requests_waiting", + documentation="Number of requests waiting to be processed.", + labelnames=labelnames).labels(*labelvalues) + + def log(self, scheduler_stats: SchedulerStats): + """Log to prometheus.""" + self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs) + self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs) From 3a5426a70a2b8de52a98a6190d3040b85b2c99d8 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Fri, 24 Jan 2025 12:12:16 -0500 Subject: [PATCH 2/3] [V1][Metrics] Add v1 test to tests/entrypoints/openai/test_metrics.py Not possible to use the run_with_both_engines fixture with its monkey-patching approach, since these tests use a module-level server fixture and the monkey-patching fixture cannot be used with a module-level fixture. Instead we can just pass env_dict = {'VLLM_USE_V1': '1'} to RemoteOpenAIServer, but copy the general approach of run_with_both_engines otherwise. Signed-off-by: Mark McLoughlin --- tests/entrypoints/openai/test_metrics.py | 41 ++++++++++++++++++++---- 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index 6523c8b6297c6..469a5fb039fb6 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -16,6 +16,24 @@ MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" +@pytest.fixture(scope="module", params=[True, False]) +def use_v1(request): + # Module-scoped variant of run_with_both_engines + # + # Use this fixture to run a test with both v0 and v1, and + # also to conditionalize the test logic e.g. + # + # def test_metrics_exist(use_v1, server, client): + # ... 
+    #       expected = EXPECTED_METRICS_V1 if use_v1 else EXPECTED_METRICS
+    #       for metric in expected:
+    #           assert metric in response.text
+    #
+    # @skip_v1 wouldn't work here because this is a module-level
+    # fixture - per-function decorators would have no effect
+    yield request.param
+
+
 @pytest.fixture(scope="module")
 def default_server_args():
     return [
@@ -36,10 +54,12 @@ def default_server_args():
         "--enable-chunked-prefill",
         "--disable-frontend-multiprocessing",
     ])
-def server(default_server_args, request):
+def server(use_v1, default_server_args, request):
     if request.param:
         default_server_args.append(request.param)
-    with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
+    env_dict = dict(VLLM_USE_V1='1' if use_v1 else '0')
+    with RemoteOpenAIServer(MODEL_NAME, default_server_args,
+                            env_dict=env_dict) as remote_server:
         yield remote_server
 
 
@@ -84,7 +104,9 @@ async def client(server):
 
 @pytest.mark.asyncio
 async def test_metrics_counts(server: RemoteOpenAIServer,
-                              client: openai.AsyncClient):
+                              client: openai.AsyncClient, use_v1: bool):
+    if use_v1:
+        pytest.skip("Skipping test on vllm V1")
     for _ in range(_NUM_REQUESTS):
         # sending a request triggers the metrics to be logged.
         await client.completions.create(
@@ -174,10 +196,15 @@ async def test_metrics_counts(server: RemoteOpenAIServer,
     "swap_space_bytes",
 ]
 
+EXPECTED_METRICS_V1 = [
+    "vllm:num_requests_running",
+    "vllm:num_requests_waiting",
+]
+
 
 @pytest.mark.asyncio
 async def test_metrics_exist(server: RemoteOpenAIServer,
-                             client: openai.AsyncClient):
+                             client: openai.AsyncClient, use_v1: bool):
     # sending a request triggers the metrics to be logged.
     await client.completions.create(model=MODEL_NAME,
                                     prompt="Hello, my name is",
@@ -187,11 +214,13 @@ async def test_metrics_exist(server: RemoteOpenAIServer,
     response = requests.get(server.url_for("metrics"))
     assert response.status_code == HTTPStatus.OK
 
-    for metric in EXPECTED_METRICS:
+    for metric in (EXPECTED_METRICS_V1 if use_v1 else EXPECTED_METRICS):
         assert metric in response.text
 
 
-def test_metrics_exist_run_batch():
+def test_metrics_exist_run_batch(use_v1: bool):
+    if use_v1:
+        pytest.skip("Skipping test on vllm V1")
     input_batch = """{"custom_id": "request-0", "method": "POST", "url": "/v1/embeddings", "body": {"model": "intfloat/e5-mistral-7b-instruct", "input": "You are a helpful assistant."}}"""  # noqa: E501
 
     base_url = "0.0.0.0"

From 3ff78fe1a81ca831655208ac415ac5c58743d24d Mon Sep 17 00:00:00 2001
From: Mark McLoughlin
Date: Sat, 25 Jan 2025 14:10:08 -0500
Subject: [PATCH 3/3] [V1][Metrics] Fix "Duplicated timeseries" error in
 test_async_llm.py

Fixes:

  ValueError: Duplicated timeseries in CollectorRegistry: {'vllm:num_requests_running'}

Same solution as in v0 - in case there are multiple engine instances in
the same process (only in tests?), just de-register the metrics before
registering them.
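As a rough illustration (not part of this change), the error can be
reproduced with nothing but prometheus_client, since creating two
collectors with the same name in the default registry is rejected:

    import prometheus_client

    def make_gauge():
        # Roughly what PrometheusStatLogger.__init__() does for each metric.
        return prometheus_client.Gauge(
            name="vllm:num_requests_running",
            documentation="Number of requests in model execution batches.",
            labelnames=["model_name"])

    make_gauge()
    make_gauge()  # ValueError: Duplicated timeseries in CollectorRegistry

De-registering any existing vLLM collectors first, as
_unregister_vllm_metrics() does, avoids this when a second
PrometheusStatLogger is constructed in the same process.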
Signed-off-by: Mark McLoughlin
---
 vllm/v1/metrics/loggers.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py
index 486ab93965c18..b84f03fa3267c 100644
--- a/vllm/v1/metrics/loggers.py
+++ b/vllm/v1/metrics/loggers.py
@@ -49,6 +49,8 @@ def __init__(self, labels: Dict[str, str]):
         labelnames = self.labels.keys()
         labelvalues = self.labels.values()
 
+        self._unregister_vllm_metrics()
+
         self.gauge_scheduler_running = prometheus_client.Gauge(
             name="vllm:num_requests_running",
             documentation="Number of requests in model execution batches.",
             labelnames=labelnames).labels(*labelvalues)
@@ -63,3 +65,10 @@ def log(self, scheduler_stats: SchedulerStats):
         """Log to prometheus."""
         self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs)
         self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs)
+
+    @staticmethod
+    def _unregister_vllm_metrics():
+        # Unregister any existing vLLM collectors (for CI/CD)
+        for collector in list(prometheus_client.REGISTRY._collector_to_names):
+            if hasattr(collector, "_name") and "vllm" in collector._name:
+                prometheus_client.REGISTRY.unregister(collector)
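As a quick end-to-end check of the series (not part of any patch above), the
two new gauges can be verified by scraping the metrics endpoint - a minimal
sketch assuming a V1 server (VLLM_USE_V1=1) is already running locally on the
default port 8000; adjust the URL for your setup:

    import requests

    # Scrape the Prometheus endpoint exposed by the OpenAI-compatible server.
    response = requests.get("http://localhost:8000/metrics")
    assert response.status_code == 200

    # The gauges added in patch 1; the same names are listed in
    # EXPECTED_METRICS_V1 by the test added in patch 2.
    for metric in ("vllm:num_requests_running", "vllm:num_requests_waiting"):
        assert metric in response.text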