Add handler for new lmi-dist #1595

Merged (9 commits, Mar 11, 2024)
4 changes: 4 additions & 0 deletions engines/python/setup/djl_python/huggingface.py
@@ -96,6 +96,9 @@ def get_rolling_batch_class_from_str(rolling_batch_type: str, is_mpi: bool,
elif rolling_batch_type == "lmi-dist":
from djl_python.rolling_batch.lmi_dist_rolling_batch import LmiDistRollingBatch
return LmiDistRollingBatch
elif rolling_batch_type == "lmi-dist-v2":
Contributor: why not just replace lmi-dist?

Contributor Author: this will be done soon once the wheel is built and added to the container.

from djl_python.rolling_batch.lmi_dist_v2_rolling_batch import LmiDistRollingBatch
return LmiDistRollingBatch
elif rolling_batch_type == "vllm":
from djl_python.rolling_batch.vllm_rolling_batch import VLLMRollingBatch
return VLLMRollingBatch
@@ -146,6 +149,7 @@ def initialize(self, properties: dict):
_rolling_batch_cls = get_rolling_batch_class_from_str(
self.hf_configs.rolling_batch.value, self.hf_configs.is_mpi,
self.model_config)
self.hf_configs.kwargs["model_config"] = self.model_config
self.rolling_batch = _rolling_batch_cls(
self.hf_configs.model_id_or_path, properties,
**self.hf_configs.kwargs)
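
For illustration, a minimal sketch of how the pieces above fit together: the rolling_batch string selects the handler class, and model_config now travels through kwargs. This mirrors what initialize() does rather than adding anything new; the model id and property values are made-up placeholders, and actually constructing the handler would require a full LMI container with a real model.

from djl_python.huggingface import get_rolling_batch_class_from_str

# Placeholder properties; in a real deployment these come from serving.properties.
properties = {"model_id": "my-org/my-model", "rolling_batch": "lmi-dist-v2", "engine": "MPI"}
model_config = None  # normally the Hugging Face model config loaded by the handler

# Select the handler class for the new backend, then construct it the way initialize() does.
rb_cls = get_rolling_batch_class_from_str("lmi-dist-v2", True, model_config)
rolling_batch = rb_cls(properties["model_id"], properties, model_config=model_config)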
@@ -91,7 +91,6 @@ def set_device(cls, properties):
@root_validator(skip_on_failure=True)
def construct_kwargs(cls, properties):
kwargs = properties['kwargs']

kwargs['trust_remote_code'] = properties['trust_remote_code']
if properties['low_cpu_mem_usage']:
kwargs["low_cpu_mem_usage"] = properties['low_cpu_mem_usage']
@@ -138,8 +137,10 @@ def construct_kwargs_quantize(cls, properties):

# TODO remove this after refactor of all handlers
# device map is not required for lmi dist and vllm
if properties['rolling_batch'] == RollingBatchEnum.lmidist or \
properties['rolling_batch'] == RollingBatchEnum.vllm:
if properties['rolling_batch'] in {
RollingBatchEnum.lmidist, RollingBatchEnum.vllm,
RollingBatchEnum.lmidist_v2
}:
return properties

if properties[
@@ -19,6 +19,8 @@
class RollingBatchEnum(str, Enum):
vllm = "vllm"
lmidist = "lmi-dist"
## this is temporary and will replace lmidist after testing
lmidist_v2 = "lmi-dist-v2"
scheduler = "scheduler"
auto = "auto"
disable = "disable"
@@ -15,7 +15,7 @@

from pydantic.v1.class_validators import validator, root_validator

from djl_python.properties_manager.properties import Properties
from djl_python.properties_manager.properties import Properties, RollingBatchEnum


class VllmQuantizeMethods(str, Enum):
@@ -45,12 +45,19 @@ class VllmRbProperties(Properties):
draft_model_tp_size: int = 1
record_acceptance_rate: Optional[bool] = False

@validator('engine')
def validate_engine(cls, engine):
if engine != "Python":
@root_validator(skip_on_failure=True)
def validate_engine(cls, properties):
engine = properties["engine"]
rolling_batch = properties["rolling_batch"]
if rolling_batch == RollingBatchEnum.vllm and engine != "Python":
raise AssertionError(
f"Need python engine to start vLLM RollingBatcher")
return engine

if rolling_batch == RollingBatchEnum.lmidist_v2 and engine != "MPI":
Contributor: Could we just have individual parameter checking for each rolling batch implementation?

raise AssertionError(
f"Need MPI engine to start lmidist_v2 RollingBatcher")

return properties

# TODO: Remove this once SageMaker resolved driver issue
@root_validator(skip_on_failure=True)
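One way to address the review question above about per-implementation checks, sketched against the pydantic v1 validators this file already uses: keep a small table of which serving engine each rolling-batch backend expects and let each properties class call a shared helper from its own root_validator. The ENGINE_REQUIREMENTS table and helper name are illustrative and not part of this PR.

from djl_python.properties_manager.properties import RollingBatchEnum

# Illustrative mapping of rolling-batch backend to the serving engine it requires.
ENGINE_REQUIREMENTS = {
    RollingBatchEnum.vllm: "Python",
    RollingBatchEnum.lmidist_v2: "MPI",
}


def check_engine_for_rolling_batch(properties: dict) -> dict:
    """Raise if the configured engine does not match the chosen rolling-batch backend."""
    required = ENGINE_REQUIREMENTS.get(properties["rolling_batch"])
    if required is not None and properties["engine"] != required:
        raise AssertionError(
            f"Need {required} engine to start {properties['rolling_batch'].value} RollingBatcher")
    return properties
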
@@ -0,0 +1,110 @@
#!/usr/bin/env python
#
# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.

from collections import OrderedDict

from lmi_dist.api import Request
from lmi_dist.init_engine import engine_from_args
from vllm import EngineArgs, SamplingParams

from djl_python.rolling_batch.vllm_rolling_batch_base import VllmRollingBatchBase, DTYPE_MAPPER
from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties


class LmiDistRollingBatch(VllmRollingBatchBase):
"""
LmiDistRollingBatch connects handler to LmiDist backend engine. It receives new
requests from the handler and sends them to the backend when space is available in the batch.
It also gets any new tokens from the backend and sends them back to the handler.
"""

def __init__(self, model_id_or_path: str, properties: dict, **kwargs):
"""
Initializes the LmiDistRollingBatch.

:param model_id_or_path (str): Currently unused since there is a copy inside properties
:param properties (dict): other properties of the model, such as decoder strategy
"""
engine_config = VllmRbProperties(**properties)
super().__init__(engine_config, kwargs.get("model_config", None))
self.init_engine()

def init_engine(self):
"""
Initializes vllm engine
"""
args = EngineArgs(
model=self.engine_config.model_id_or_path,
tensor_parallel_size=self.engine_config.tensor_parallel_degree,
dtype=DTYPE_MAPPER[self.engine_config.dtype],
seed=0,
max_model_len=self.engine_config.max_model_len,
enforce_eager=self.engine_config.enforce_eager,
gpu_memory_utilization=self.engine_config.gpu_memory_utilization,
max_num_batched_tokens=self.engine_config.
max_rolling_batch_prefill_tokens,
trust_remote_code=self.engine_config.trust_remote_code,
load_format=self.engine_config.load_format,
quantization=self.engine_config.quantize,
revision=self.engine_config.revision)
self.engine = engine_from_args(args)

def reset(self) -> None:
"""
Aborts all requests
"""
self.engine.reset(self.request_cache.keys())
self.request_cache = OrderedDict()
super().reset()

def add_request(self, request_id: str, prompt: str,
sampling_params: SamplingParams):
"""
Adds request to the engine
"""
lmi_dist_request = Request(id=request_id,
prompt=prompt,
sampling_params=sampling_params)
self.engine.add_request(lmi_dist_request)

def translate_to_engine_params(self, parameters: dict):
"""
Helper function to convert DJL Serving parameter names to parameter names
that lmidist_v2 recognizes.

:param parameters (dict): Parameters pertaining to a specific request

:return: The same parameters dict, but with VLLM style parameter names.
"""
parameters.pop('seed', None)
Contributor: seed is supported by vllm now: vllm-project/vllm#2514

Contributor Author: okay, will remove. In this PR, I kept what vllm rolling batch does currently and wanted to tune params in the next PR.

parameters.pop('do_sample', None)
Contributor: Shouldn't do_sample=False map to temperature=0, basically? vllm does support greedy; it just uses temperature to accomplish that.

Contributor Author: As I said above, I used the vllm config here. But it's a good point: I believe this is removed for vllm because it doesn't support the do_sample parameter, whereas lmi-dist should (for backwards compatibility), and we should set default sampling params.

Contributor: Sure, we can punt on this, but I think this is another point to bring up about how consistent an interface we want to provide across engine/backend, vs. how closely the interface should change to match each engine/backend.

if "max_new_tokens" in parameters.keys():
parameters["max_tokens"] = parameters.pop("max_new_tokens")
if "stop_sequences" in parameters.keys():
parameters["stop"] = parameters.pop("stop_sequences")
if "ignore_eos_token" in parameters.keys():
parameters["ignore_eos"] = parameters.pop("ignore_eos")
return parameters
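
A hedged sketch of the alternative discussed in the review threads above: keep seed (recent vLLM SamplingParams accepts a per-request seed, per vllm-project/vllm#2514) and translate do_sample=False into greedy decoding via temperature=0 instead of dropping both. This is not part of the PR; it only illustrates the mapping being debated.

def translate_to_engine_params_sketch(parameters: dict) -> dict:
    """Illustrative variant: keep seed and map do_sample=False to greedy decoding."""
    if not parameters.pop("do_sample", True):
        # Greedy decoding in vLLM is expressed as temperature == 0.
        parameters["temperature"] = 0.0
    if "max_new_tokens" in parameters:
        parameters["max_tokens"] = parameters.pop("max_new_tokens")
    if "stop_sequences" in parameters:
        parameters["stop"] = parameters.pop("stop_sequences")
    if "ignore_eos_token" in parameters:
        parameters["ignore_eos"] = parameters.pop("ignore_eos_token")
    return parameters

This keeps backwards compatibility with the do_sample flag while still letting vLLM's temperature-based greedy path do the work.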

def get_request_id(self, request):
"""
Get request id that will be set to backend engine request
"""
return str(request.id)

def preprocess_requests(self, requests):
"""
Currently not applicable for VLLM.
"""
raise NotImplementedError(
"Not implemented for lmidist_v2 rolling batcher")
159 changes: 35 additions & 124 deletions engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
@@ -10,31 +10,16 @@
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import json
import logging
from collections import OrderedDict

from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.utils import random_uuid
from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, Token

from djl_python.rolling_batch.vllm_rolling_batch_base import VllmRollingBatchBase, DTYPE_MAPPER
from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties

DTYPE_MAPPER = {
"fp32": "float32",
"fp16": "float16",
"bf16": "bfloat16",
"auto": "auto"
}

FINISH_REASON_MAPPER = {
"length": "length",
"stop": "eos_token",
"abort": "abort"
}


class VLLMRollingBatch(RollingBatch):
class VLLMRollingBatch(VllmRollingBatchBase):
"""
VLLMRollingBatch connects the handler to the backend (VLLM inference). It receives new
requests from the handler and sends them to the backend when there is space available in the batch.
@@ -50,28 +35,32 @@ def __init__(self, model_id_or_path: str, properties: dict,
:param model_id_or_path: Currently unused since there is a copy inside properties
:param properties: other properties of the model, such as decoder strategy
"""
self.vllm_configs = VllmRbProperties(**properties)
super().__init__(waiting_steps=self.vllm_configs.waiting_steps,
output_formatter=self.vllm_configs.output_formatter)
engine_config = VllmRbProperties(**properties)
super().__init__(engine_config, kwargs.get("model_config", None))
self.init_engine()

def init_engine(self):
"""
Initializes vllm engine
"""
args = EngineArgs(
model=self.vllm_configs.model_id_or_path,
tensor_parallel_size=self.vllm_configs.tensor_parallel_degree,
dtype=DTYPE_MAPPER[self.vllm_configs.dtype],
model=self.engine_config.model_id_or_path,
tensor_parallel_size=self.engine_config.tensor_parallel_degree,
dtype=DTYPE_MAPPER[self.engine_config.dtype],
seed=0,
max_model_len=self.vllm_configs.max_model_len,
enforce_eager=self.vllm_configs.enforce_eager,
gpu_memory_utilization=self.vllm_configs.gpu_memory_utilization,
max_num_batched_tokens=self.vllm_configs.
max_model_len=self.engine_config.max_model_len,
enforce_eager=self.engine_config.enforce_eager,
gpu_memory_utilization=self.engine_config.gpu_memory_utilization,
max_num_batched_tokens=self.engine_config.
max_rolling_batch_prefill_tokens,
trust_remote_code=self.vllm_configs.trust_remote_code,
load_format=self.vllm_configs.load_format,
quantization=self.vllm_configs.quantize,
draft_model=self.vllm_configs.speculative_draft_model,
speculate_length=self.vllm_configs.speculative_length,
draft_model_tp_size=self.vllm_configs.draft_model_tp_size,
revision=self.vllm_configs.revision)
trust_remote_code=self.engine_config.trust_remote_code,
load_format=self.engine_config.load_format,
quantization=self.engine_config.quantize,
draft_model=self.engine_config.speculative_draft_model,
speculate_length=self.engine_config.speculative_length,
draft_model_tp_size=self.engine_config.draft_model_tp_size,
revision=self.engine_config.revision)
self.engine = LLMEngine.from_engine_args(args)
self.request_cache = OrderedDict()

def reset(self) -> None:
"""
@@ -82,7 +71,14 @@ def reset(self) -> None:
self.request_cache = OrderedDict()
super().reset()

def translate_vllm_params(self, parameters: dict) -> dict:
def add_request(self, request_id: str, prompt: str,
sampling_params: SamplingParams):
"""
Adds request to the engine
"""
self.engine.add_request(request_id, prompt, sampling_params)

def translate_to_engine_params(self, parameters: dict) -> dict:
"""
Helper function to convert DJL Serving parameter names to parameter names
that VLLM recognizes.
@@ -101,96 +97,11 @@ def translate_vllm_params(self, parameters: dict) -> dict:
parameters["ignore_eos"] = parameters.pop("ignore_eos")
return parameters

@stop_on_any_exception
def inference(self, input_data: list[str], parameters: list[dict]) -> list:
def get_request_id(self, request):
"""
Adds new requests and gets output tokens from the backend.

:param input_data: List of input prompts.
:param parameters: List of settings pertaining to each request.

:return results: List of dictionaries, one for each request, that contain output tokens and other data.
Get request id that will be set to backend engine request
"""
batch_size = len(input_data)
new_requests = self.get_new_requests(input_data, parameters,
batch_size)
# step 0: register new requests to engine
for request in new_requests:
request_id = random_uuid()
params = self.translate_vllm_params(request.parameters)
sampling_params = SamplingParams(**params)
self.engine.add_request(request_id, request.input_text,
sampling_params)
self.request_cache[request_id] = {
"curr_length": 0,
"text": "",
"cumulative_logprob": 0.0,
"log_prob": 0.0,
"finished": False,
"finish_reason": None
}
request_outputs = self.engine.step()
# step 1: put result to cache
for request_output in request_outputs:
req_id = request_output.request_id
self.request_cache[req_id]["id"] = request_output.outputs[
0].token_ids[-1]
self.request_cache[req_id]["text"] = request_output.outputs[0].text
# calculate log_prob of the token based on the diff between two cumulative log probs
self.request_cache[req_id]["log_prob"] = request_output.outputs[
0].cumulative_logprob - self.request_cache[req_id][
"cumulative_logprob"]
self.request_cache[req_id][
"cumulative_logprob"] = request_output.outputs[
0].cumulative_logprob
self.request_cache[req_id][
"finish_reason"] = request_output.outputs[0].finish_reason
if len(request_output.outputs) > 1:
logging.warning(
f"Finding more than 1 output for single request {len(request_output.outputs)}"
f"Beam search is not supported yet, use first output by default"
)
self.request_cache[req_id]["finished"] = request_output.finished
# Record SD metrics
completion_output = request_output.outputs[0]
if self.vllm_configs.record_acceptance_rate and request_output.finished and completion_output.acceptance_history:
record = {}
record["id"] = req_id
if len(completion_output.acceptance_history) > 0:
record["mean_acceptance"] = 1.0 * sum(
completion_output.acceptance_history) / len(
completion_output.acceptance_history)
else:
record["mean_acceptance"] = 0
record["prompt_size"] = len(request_output.prompt_token_ids)
record["output_size"] = len(completion_output.token_ids)
logging.info(f"Speculative Decoding {record}")
# step 2: send result back
finished_id = []
for (key, cache), request in zip(self.request_cache.items(),
self.active_requests):
finish_reason = None
if cache["finished"]:
finished_id.append(key)
finish_reason = FINISH_REASON_MAPPER.get(
cache["finish_reason"], None)
text = cache["text"][cache["curr_length"]:]
if len(text) > 0:
# token id is not determined since there could be multiple token comes at the same time
# only return the last one
token = Token(cache['id'], text, cache["log_prob"])
request.set_next_token(token, self.output_formatter,
cache["finished"], finish_reason)
else:
request.set_next_token("", self.output_formatter,
cache["finished"], finish_reason)
cache["curr_length"] = len(cache["text"])

# step 3: clean finished requests
for key in finished_id:
self.request_cache.pop(key)

return self.postprocess_results()
return random_uuid()
Contributor: Is there a reason not to just be consistent and use req.id here as well?

Contributor Author: I was not sure if there's a specific reason why req.id is not used; I asked internally but did not get a response, so I kept it as is.
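
For reference, the consistent variant raised in the thread above would mirror the lmi-dist handler and reuse the handler-assigned id; whether vLLM has a reason to require a random uuid here is left open, as it is in the discussion. Sketch only, not part of the PR:

def get_request_id(self, request):
    """Consistent variant: reuse the handler-assigned request id instead of a random uuid."""
    return str(request.id)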


def preprocess_requests(self, requests):
"""