Add handler for new lmi-dist #1595
djl_python/properties_manager/vllm_rb_properties.py

@@ -15,7 +15,7 @@
 from pydantic.v1.class_validators import validator, root_validator

-from djl_python.properties_manager.properties import Properties
+from djl_python.properties_manager.properties import Properties, RollingBatchEnum


 class VllmQuantizeMethods(str, Enum):
@@ -45,12 +45,19 @@ class VllmRbProperties(Properties):
     draft_model_tp_size: int = 1
     record_acceptance_rate: Optional[bool] = False

-    @validator('engine')
-    def validate_engine(cls, engine):
-        if engine != "Python":
+    @root_validator(skip_on_failure=True)
+    def validate_engine(cls, properties):
+        engine = properties["engine"]
+        rolling_batch = properties["rolling_batch"]
+        if rolling_batch == RollingBatchEnum.vllm and engine != "Python":
             raise AssertionError(
                 f"Need python engine to start vLLM RollingBatcher")
-        return engine
+
+        if rolling_batch == RollingBatchEnum.lmidist_v2 and engine != "MPI":
+            raise AssertionError(
+                f"Need MPI engine to start lmidist_v2 RollingBatcher")
+
+        return properties

     # TODO: Remove this once SageMaker resolved driver issue
     @root_validator(skip_on_failure=True)

Review thread on validate_engine:
Reviewer: Could we just have individual parameter checking for each rolling batch implementation?
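A minimal sketch of what the reviewer's suggestion might look like, with one pydantic v1 root validator per rolling batch implementation. The class names, validator names, and the simplified stand-ins for Properties and RollingBatchEnum are illustrative only; they are not part of this PR:

from enum import Enum

from pydantic.v1 import BaseModel
from pydantic.v1.class_validators import root_validator


class RollingBatchEnum(str, Enum):
    # simplified stand-in for the real enum in djl_python.properties_manager.properties
    vllm = "vllm"
    lmidist_v2 = "lmidist_v2"


class VllmRbPropertiesSketch(BaseModel):
    # simplified stand-in for VllmRbProperties(Properties)
    engine: str = "Python"
    rolling_batch: RollingBatchEnum = RollingBatchEnum.vllm

    @root_validator(skip_on_failure=True)
    def validate_vllm_engine(cls, properties):
        # checks only the vLLM rolling batch implementation
        if properties["rolling_batch"] == RollingBatchEnum.vllm \
                and properties["engine"] != "Python":
            raise AssertionError("Need python engine to start vLLM RollingBatcher")
        return properties

    @root_validator(skip_on_failure=True)
    def validate_lmidist_v2_engine(cls, properties):
        # checks only the lmidist_v2 rolling batch implementation
        if properties["rolling_batch"] == RollingBatchEnum.lmidist_v2 \
                and properties["engine"] != "MPI":
            raise AssertionError("Need MPI engine to start lmidist_v2 RollingBatcher")
        return properties


ok = VllmRbPropertiesSketch(engine="MPI", rolling_batch="lmidist_v2")  # passes
try:
    VllmRbPropertiesSketch(engine="MPI", rolling_batch="vllm")
except Exception as e:
    # pydantic wraps the AssertionError in a ValidationError
    print(e)

Each validator only inspects its own rolling batch value, so adding a new implementation means adding a new validator rather than growing a single validate_engine.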
New file: LmiDistRollingBatch rolling batch handler

@@ -0,0 +1,110 @@
#!/usr/bin/env python
#
# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

from collections import OrderedDict

from lmi_dist.api import Request
from lmi_dist.init_engine import engine_from_args
from vllm import EngineArgs, SamplingParams

from djl_python.rolling_batch.vllm_rolling_batch_base import VllmRollingBatchBase, DTYPE_MAPPER
from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties


class LmiDistRollingBatch(VllmRollingBatchBase):
    """
    LmiDistRollingBatch connects handler to LmiDist backend engine. It receives new
    requests from the handler and sends them to the backend when space is available in the batch.
    It also gets any new tokens from the backend and sends them back to the handler.
    """

    def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
        """
        Initializes the LmiDistRollingBatch.

        :param model_id_or_path (str): Currently unused since there is a copy inside properties
        :param properties (dict): other properties of the model, such as decoder strategy
        """
        engine_config = VllmRbProperties(**properties)
        super().__init__(engine_config, kwargs.get("model_config", None))
        self.init_engine()

    def init_engine(self):
        """
        Initializes vllm engine
        """
        args = EngineArgs(
            model=self.engine_config.model_id_or_path,
            tensor_parallel_size=self.engine_config.tensor_parallel_degree,
            dtype=DTYPE_MAPPER[self.engine_config.dtype],
            seed=0,
            max_model_len=self.engine_config.max_model_len,
            enforce_eager=self.engine_config.enforce_eager,
            gpu_memory_utilization=self.engine_config.gpu_memory_utilization,
            max_num_batched_tokens=self.engine_config.
            max_rolling_batch_prefill_tokens,
            trust_remote_code=self.engine_config.trust_remote_code,
            load_format=self.engine_config.load_format,
            quantization=self.engine_config.quantize,
            revision=self.engine_config.revision)
        self.engine = engine_from_args(args)

    def reset(self) -> None:
        """
        Aborts all requests
        """
        self.engine.reset(self.request_cache.keys())
        self.request_cache = OrderedDict()
        super().reset()

    def add_request(self, request_id: str, prompt: str,
                    sampling_params: SamplingParams):
        """
        Adds request to the engine
        """
        lmi_dist_request = Request(id=request_id,
                                   prompt=prompt,
                                   sampling_params=sampling_params)
        self.engine.add_request(lmi_dist_request)

    def translate_to_engine_params(self, parameters: dict):
        """
        Helper function to convert DJL Serving parameter names to parameter names
        that lmidist_v2 recognizes.

        :param parameters (dict): Parameters pertaining to a specific request

        :return: The same parameters dict, but with VLLM style parameter names.
        """
        parameters.pop('seed', None)

Review thread on parameters.pop('seed', None):
Reviewer: seed is supported by vllm now: vllm-project/vllm#2514
Author: okay, will remove. In this PR, I kept what vllm rolling batch does currently and wanted to tune params in the next PR.

        parameters.pop('do_sample', None)

Review thread on parameters.pop('do_sample', None):
Reviewer: Shouldn't …
Author: As I said above, I used vllm config here. But it's a good point, I believe this is removed for …
Reviewer: Sure, we can punt on this, but I think this is another point to bring up about how consistent of an interface we want to provide across engine/backend, vs. how closely the interface should change to match each engine/backend.

        if "max_new_tokens" in parameters.keys():
            parameters["max_tokens"] = parameters.pop("max_new_tokens")
        if "stop_sequences" in parameters.keys():
            parameters["stop"] = parameters.pop("stop_sequences")
        if "ignore_eos_token" in parameters.keys():
            # pop the DJL-named key and store it under the engine's name
            parameters["ignore_eos"] = parameters.pop("ignore_eos_token")
        return parameters

    def get_request_id(self, request):
        """
        Get request id that will be set to backend engine request
        """
        return str(request.id)

    def preprocess_requests(self, requests):
        """
        Currently not applicable for VLLM.
        """
        raise NotImplementedError(
            "Not implemented for lmidist_v2 rolling batcher")
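To make the parameter renaming above concrete, here is a small standalone sketch of the same translation applied to an invented request payload. The function re-implements the handler's logic outside the class purely for illustration and is not part of the PR:

def translate_to_engine_params_sketch(parameters: dict) -> dict:
    # copy so the caller's dict is left untouched (the handler mutates in place)
    parameters = dict(parameters)
    parameters.pop('seed', None)       # dropped, mirroring the current vLLM handler
    parameters.pop('do_sample', None)  # dropped, mirroring the current vLLM handler
    if "max_new_tokens" in parameters:
        parameters["max_tokens"] = parameters.pop("max_new_tokens")
    if "stop_sequences" in parameters:
        parameters["stop"] = parameters.pop("stop_sequences")
    if "ignore_eos_token" in parameters:
        parameters["ignore_eos"] = parameters.pop("ignore_eos_token")
    return parameters


example = {
    "max_new_tokens": 128,
    "stop_sequences": ["</s>"],
    "temperature": 0.7,
    "seed": 42,
    "do_sample": True,
}
print(translate_to_engine_params_sketch(example))
# {'temperature': 0.7, 'max_tokens': 128, 'stop': ['</s>']}

Presumably the shared base class then builds vllm SamplingParams from the translated dict before calling add_request, mirroring the old inference flow that is removed from VLLMRollingBatch further down.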
Changes to VLLMRollingBatch (vllm rolling batch handler)

@@ -10,31 +10,16 @@
 # or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
 # BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
 # the specific language governing permissions and limitations under the License.
-import json
-import logging
 from collections import OrderedDict

 from vllm import EngineArgs, LLMEngine, SamplingParams
 from vllm.utils import random_uuid
-from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, Token

+from djl_python.rolling_batch.vllm_rolling_batch_base import VllmRollingBatchBase, DTYPE_MAPPER
 from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties

-DTYPE_MAPPER = {
-    "fp32": "float32",
-    "fp16": "float16",
-    "bf16": "bfloat16",
-    "auto": "auto"
-}
-
-FINISH_REASON_MAPPER = {
-    "length": "length",
-    "stop": "eos_token",
-    "abort": "abort"
-}
-

-class VLLMRollingBatch(RollingBatch):
+class VLLMRollingBatch(VllmRollingBatchBase):
     """
     VLLMRollingBatch connects the handler to the backend (VLLM inference). It receives new
     requests from the handler and sends them to the backend when there is space available in the batch.
@@ -50,28 +35,32 @@ def __init__(self, model_id_or_path: str, properties: dict,
         :param model_id_or_path: Currently unused since there is a copy inside properties
         :param properties: other properties of the model, such as decoder strategy
         """
-        self.vllm_configs = VllmRbProperties(**properties)
-        super().__init__(waiting_steps=self.vllm_configs.waiting_steps,
-                         output_formatter=self.vllm_configs.output_formatter)
+        engine_config = VllmRbProperties(**properties)
+        super().__init__(engine_config, kwargs.get("model_config", None))
+        self.init_engine()
+
+    def init_engine(self):
+        """
+        Initializes vllm engine
+        """
         args = EngineArgs(
-            model=self.vllm_configs.model_id_or_path,
-            tensor_parallel_size=self.vllm_configs.tensor_parallel_degree,
-            dtype=DTYPE_MAPPER[self.vllm_configs.dtype],
+            model=self.engine_config.model_id_or_path,
+            tensor_parallel_size=self.engine_config.tensor_parallel_degree,
+            dtype=DTYPE_MAPPER[self.engine_config.dtype],
             seed=0,
-            max_model_len=self.vllm_configs.max_model_len,
-            enforce_eager=self.vllm_configs.enforce_eager,
-            gpu_memory_utilization=self.vllm_configs.gpu_memory_utilization,
-            max_num_batched_tokens=self.vllm_configs.
+            max_model_len=self.engine_config.max_model_len,
+            enforce_eager=self.engine_config.enforce_eager,
+            gpu_memory_utilization=self.engine_config.gpu_memory_utilization,
+            max_num_batched_tokens=self.engine_config.
             max_rolling_batch_prefill_tokens,
-            trust_remote_code=self.vllm_configs.trust_remote_code,
-            load_format=self.vllm_configs.load_format,
-            quantization=self.vllm_configs.quantize,
-            draft_model=self.vllm_configs.speculative_draft_model,
-            speculate_length=self.vllm_configs.speculative_length,
-            draft_model_tp_size=self.vllm_configs.draft_model_tp_size,
-            revision=self.vllm_configs.revision)
+            trust_remote_code=self.engine_config.trust_remote_code,
+            load_format=self.engine_config.load_format,
+            quantization=self.engine_config.quantize,
+            draft_model=self.engine_config.speculative_draft_model,
+            speculate_length=self.engine_config.speculative_length,
+            draft_model_tp_size=self.engine_config.draft_model_tp_size,
+            revision=self.engine_config.revision)
         self.engine = LLMEngine.from_engine_args(args)
         self.request_cache = OrderedDict()
@@ -82,7 +71,14 @@ def reset(self) -> None:
         self.request_cache = OrderedDict()
         super().reset()

-    def translate_vllm_params(self, parameters: dict) -> dict:
+    def add_request(self, request_id: str, prompt: str,
+                    sampling_params: SamplingParams):
+        """
+        Adds request to the engine
+        """
+        self.engine.add_request(request_id, prompt, sampling_params)
+
+    def translate_to_engine_params(self, parameters: dict) -> dict:
         """
         Helper function to convert DJL Serving parameter names to parameter names
         that VLLM recognizes.
@@ -101,96 +97,11 @@ def translate_vllm_params(self, parameters: dict) -> dict:
             parameters["ignore_eos"] = parameters.pop("ignore_eos")
         return parameters

-    @stop_on_any_exception
-    def inference(self, input_data: list[str], parameters: list[dict]) -> list:
+    def get_request_id(self, request):
         """
-        Adds new requests and gets output tokens from the backend.
-
-        :param input_data: List of input prompts.
-        :param parameters: List of settings pertaining to each request.
-
-        :return results: List of dictionaries, one for each request, that contain output tokens and other data.
+        Get request id that will be set to backend engine request
         """
-        batch_size = len(input_data)
-        new_requests = self.get_new_requests(input_data, parameters,
-                                             batch_size)
-        # step 0: register new requests to engine
-        for request in new_requests:
-            request_id = random_uuid()
-            params = self.translate_vllm_params(request.parameters)
-            sampling_params = SamplingParams(**params)
-            self.engine.add_request(request_id, request.input_text,
-                                    sampling_params)
-            self.request_cache[request_id] = {
-                "curr_length": 0,
-                "text": "",
-                "cumulative_logprob": 0.0,
-                "log_prob": 0.0,
-                "finished": False,
-                "finish_reason": None
-            }
-        request_outputs = self.engine.step()
-        # step 1: put result to cache
-        for request_output in request_outputs:
-            req_id = request_output.request_id
-            self.request_cache[req_id]["id"] = request_output.outputs[
-                0].token_ids[-1]
-            self.request_cache[req_id]["text"] = request_output.outputs[0].text
-            # calculate log_prob of the token based on the diff between two cumulative log probs
-            self.request_cache[req_id]["log_prob"] = request_output.outputs[
-                0].cumulative_logprob - self.request_cache[req_id][
-                    "cumulative_logprob"]
-            self.request_cache[req_id][
-                "cumulative_logprob"] = request_output.outputs[
-                    0].cumulative_logprob
-            self.request_cache[req_id][
-                "finish_reason"] = request_output.outputs[0].finish_reason
-            if len(request_output.outputs) > 1:
-                logging.warning(
-                    f"Finding more than 1 output for single request {len(request_output.outputs)}"
-                    f"Beam search is not supported yet, use first output by default"
-                )
-            self.request_cache[req_id]["finished"] = request_output.finished
-            # Record SD metrics
-            completion_output = request_output.outputs[0]
-            if self.vllm_configs.record_acceptance_rate and request_output.finished and completion_output.acceptance_history:
-                record = {}
-                record["id"] = req_id
-                if len(completion_output.acceptance_history) > 0:
-                    record["mean_acceptance"] = 1.0 * sum(
-                        completion_output.acceptance_history) / len(
-                            completion_output.acceptance_history)
-                else:
-                    record["mean_acceptance"] = 0
-                record["prompt_size"] = len(request_output.prompt_token_ids)
-                record["output_size"] = len(completion_output.token_ids)
-                logging.info(f"Speculative Decoding {record}")
-        # step 2: send result back
-        finished_id = []
-        for (key, cache), request in zip(self.request_cache.items(),
-                                         self.active_requests):
-            finish_reason = None
-            if cache["finished"]:
-                finished_id.append(key)
-                finish_reason = FINISH_REASON_MAPPER.get(
-                    cache["finish_reason"], None)
-            text = cache["text"][cache["curr_length"]:]
-            if len(text) > 0:
-                # token id is not determined since there could be multiple token comes at the same time
-                # only return the last one
-                token = Token(cache['id'], text, cache["log_prob"])
-                request.set_next_token(token, self.output_formatter,
-                                       cache["finished"], finish_reason)
-            else:
-                request.set_next_token("", self.output_formatter,
-                                       cache["finished"], finish_reason)
-            cache["curr_length"] = len(cache["text"])
-
-        # step 3: clean finished requests
-        for key in finished_id:
-            self.request_cache.pop(key)
-
-        return self.postprocess_results()
+        return random_uuid()

     def preprocess_requests(self, requests):
         """

Review thread on return random_uuid():
Reviewer: Is there a reason not to just be consistent and use req.id here as well?
Author: I was not sure if there's a specific reason why req.id is not used, I asked internally but did not get a response so kept it as is.
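Neither file in this diff shows VllmRollingBatchBase itself. Judging from what both handlers override, a rough sketch of the contract it appears to define is below; the signatures are inferred from the two subclasses above, not copied from the real djl_python.rolling_batch.vllm_rolling_batch_base module, and the real base presumably also extends RollingBatch and hosts the shared inference loop that was removed from VLLMRollingBatch here:

from abc import ABC, abstractmethod
from collections import OrderedDict


class VllmRollingBatchBaseSketch(ABC):
    # inferred sketch only; the real base class may differ

    def __init__(self, engine_config, model_config=None):
        self.engine_config = engine_config
        self.model_config = model_config
        self.request_cache = OrderedDict()

    @abstractmethod
    def init_engine(self):
        """Build the backend engine (vLLM LLMEngine or lmi-dist engine)."""

    @abstractmethod
    def add_request(self, request_id: str, prompt: str, sampling_params):
        """Register one request with the backend engine."""

    @abstractmethod
    def translate_to_engine_params(self, parameters: dict) -> dict:
        """Rename DJL Serving parameter names to the engine's names."""

    @abstractmethod
    def get_request_id(self, request) -> str:
        """Return the id to use for a request inside the backend engine."""

    @abstractmethod
    def preprocess_requests(self, requests):
        """Optional hook for request preprocessing before batching."""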
Review thread on the pull request:
Reviewer: why not just replace lmi-dist?
Author: this will be done soon once the wheel is built and added to the container.