Improve robustness for llm #2321

Merged · 15 commits · Dec 14, 2023
5 changes: 0 additions & 5 deletions llm/fastdeploy_llm/engine.py
@@ -548,11 +548,6 @@ def run(infer_engine):
flag_ready_array[rank] = 1 # init done

while 1:
if serving_pid > 0 and (not is_process_running(serving_pid)):
print(
"[IMPORTANT] The serving process {} is not running, will terminate engine now.".
format(serving_pid))
break
if flag_begin_array[rank] != 1:
continue

85 changes: 63 additions & 22 deletions llm/fastdeploy_llm/serving/triton_model.py
@@ -19,8 +19,10 @@
import time
import numpy as np
import functools
from collections import defaultdict
from fastdeploy_llm.serving.serving_model import ServingModel
from fastdeploy_llm.utils.logging_util import logger
from fastdeploy_llm.utils.logging_util import error_format, ErrorCode, ErrorType
from fastdeploy_llm.task import Task, BatchTask
import fastdeploy_llm as fdlm

@@ -31,22 +33,27 @@
pass


tokens_all_dict = defaultdict(list)

def stream_call_back(call_back_task, token_tuple, index, is_last_token,
sender):
out = dict()
out["result"] = token_tuple[1]
out["req_id"] = call_back_task.task_id
out["token_ids"] = [token_tuple[0]]
out['send_idx'] = index
out["is_end"] = 1 if is_last_token else 0
out["is_end"] = is_last_token
tokens_all_dict[call_back_task.task_id].append(token_tuple[1])
out_tensor = pb_utils.Tensor(
"OUT", np.array(
[json.dumps(out)], dtype=np.object_))
if is_last_token:
logger.info("Model output for req_id: {} results_all: {}".format(call_back_task.task_id, ''.join(tokens_all_dict[call_back_task.task_id])))

Reviewer comment (Collaborator):

call_back_task.result.completion_tokens stores all of the previously generated results (not including the current token).

So all of the token ids so far can be obtained with:

all_token_ids = [t[0] for t in call_back_task.result.completion_tokens]
all_strs = "".join(t[1] for t in call_back_task.result.completion_tokens)

Adding the current one:

all_token_ids.append(token_tuple[0])
all_strs += token_tuple[1]

sender[call_back_task.task_id].send(
pb_utils.InferenceResponse([out_tensor]),
flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
del sender[call_back_task.task_id]
del tokens_all_dict[call_back_task.task_id]
else:
sender[call_back_task.task_id].send(
pb_utils.InferenceResponse([out_tensor]))
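
Based on the reviewer comment above, here is a minimal sketch of how the final-token branch of stream_call_back could build the same log line from call_back_task.result.completion_tokens instead of the module-level tokens_all_dict. It assumes, as the reviewer describes, that completion_tokens holds (token_id, token_str) tuples for every previously generated token:

# Sketch only: inside stream_call_back, replacing the tokens_all_dict lookup.
# `completion_tokens` is assumed to hold (token_id, token_str) tuples for all
# tokens generated before the current token_tuple.
if is_last_token:
    previous = call_back_task.result.completion_tokens
    results_all = "".join(t[1] for t in previous) + token_tuple[1]
    logger.info("Model output for req_id: {} results_all: {}".format(
        call_back_task.task_id, results_all))

This would avoid keeping per-request state in a global dictionary that has to be cleaned up manually.
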
@@ -68,10 +75,14 @@ def initialize(self, args):
using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
self.model_config)
if not using_decoupled:
raise pb_utils.TritonModelException(
"""the model `{}` can generate any number of responses per request,
error_type = ErrorType.Server
error_code = ErrorCode.S0001
error_info = """the model `{}` can generate any number of responses per request,
enable decoupled transaction policy in model configuration to
serve this model""".format(args["model_name"]))
serve this model""".format(args["model_name"])
error_msg = error_format.format(error_type.name, error_code.name, error_info)
logger.error(error_msg)
raise pb_utils.TritonModelException(error_msg)

parameters = self.model_config["parameters"]

@@ -112,10 +123,13 @@ def execute(self, requests):
if isinstance(data, list):
data = data[0]
except Exception as e:
error_type = ErrorType.Query
error_code = ErrorCode.C0000
error_info = "Cannot load json data from request, received data = {} error={}.".format(request_tensor, e)
error_msg = error_format.format(error_type.name, error_code.name, error_info)
logger.error(error_msg)
error_res = pb_utils.InferenceResponse(
error=pb_utils.TritonError(
"Cannot load json data from request, error={}.".format(
e)))
error=pb_utils.TritonError(error_msg))
res_sender = request.get_response_sender()
res_sender.send(
error_res,
@@ -127,9 +141,13 @@
try:
task.from_dict(data)
except Exception as e:
error_res = pb_utils.InferenceResponse(error=pb_utils.TritonError(
"There's error while deserializing data from request, error={}".
format(e)))
error_type = ErrorType.Query
error_code = ErrorCode.C0001
error_info = "There's error while deserializing data from request, received data = {} error={}".format(data, e)
error_msg = error_format.format(error_type.name, error_code.name, error_info)
logger.error(error_msg)
error_res = pb_utils.InferenceResponse(
error=pb_utils.TritonError(error_msg))
res_sender = request.get_response_sender()
res_sender.send(
error_res,
@@ -140,9 +158,13 @@
if task.task_id is None:
task.task_id = str(uuid.uuid4())
if task.task_id in self.response_handler:
error_type = ErrorType.Query
error_code = ErrorCode.C0001
error_info = "Task id conflict with {}.".format(task.task_id)
error_msg = error_format.format(error_type.name, error_code.name, error_info)
logger.error(error_msg)
error_res = pb_utils.InferenceResponse(
error=pb_utils.TritonError(
"Task id conflict with {}.".format(task.task_id)))
error=pb_utils.TritonError(error_msg))
res_sender = request.get_response_sender()
res_sender.send(
error_res,
@@ -153,10 +175,13 @@
try:
task.check(self.config.max_dec_len)
except Exception as e:
error_type = ErrorType.Query
error_code = ErrorCode.C0001
error_info = "There's error while checking task, task={} error={}".format(task, e)
error_msg = error_format.format(error_type.name, error_code.name, error_info)
logger.error(error_msg)
error_res = pb_utils.InferenceResponse(
error=pb_utils.TritonError(
"There's error while checking task, error={}".format(
e)))
error=pb_utils.TritonError(error_msg))
res_sender = request.get_response_sender()
res_sender.send(
error_res,
@@ -165,9 +190,12 @@

# 5. check if the requests queue is full
if self.model.requests_queue.qsize() > self.config.max_queue_num:
error_res = pb_utils.InferenceResponse(error=pb_utils.TritonError(
"The queue is full now(size={}), please wait for a while.".
format(self.model.max_queue_num)))
error_type = ErrorType.Server
error_code = ErrorCode.S0000
error_info = "The queue is full now(size={}), please wait for a while.".format(self.model.max_queue_num)
error_msg = error_format.format(error_type.name, error_code.name, error_info)
logger.error(error_msg)
error_res = pb_utils.InferenceResponse(error=pb_utils.TritonError(error_msg))
res_sender = request.get_response_sender()
res_sender.send(
error_res,
@@ -195,10 +223,12 @@ def execute(self, requests):
try:
self.model.add_request(task)
except Exception as e:
error_res = pb_utils.InferenceResponse(
error=pb_utils.TritonError(
"There's error while inserting new request, error={}".
format(e)))
error_type = ErrorType.Query
error_code = ErrorCode.C0001
error_info = "There's error while inserting new request, task={} error={}".format(task, e)
error_msg = error_format.format(error_type.name, error_code.name, error_info)
logger.error(error_msg)
error_res = pb_utils.InferenceResponse(error=pb_utils.TritonError(error_msg))
res_sender = request.get_response_sender()
res_sender.send(
error_res,
@@ -208,5 +238,16 @@

def finalize(self):
logger.info("The triton server is going to terminating...")
info_type = ErrorType.Server
info_code = ErrorCode.S0002
info_msg = error_format.format(info_type.name, info_code.name, "The triton server is going to terminating...")
logger.info(info_msg)
self.model.stop()
os.system("""
bash -c 'pids=$(ps auxww | grep -E "triton_python_backend_stub|multiprocessing.resource_tracker|engine.py" | grep -v grep | awk '"'"'{print $2}'"'"');
echo $pids;
for pid in ${pids[@]}; do
kill -9 ${pid}
done;'
""")
logger.info("The triton server is terminated, byebye.")
10 changes: 6 additions & 4 deletions llm/fastdeploy_llm/utils/launch_infer.py
@@ -41,14 +41,16 @@ def launch(device_ids, **kwargs: dict):
+ "launch_infer.launch, please set them and try again".format(
missing_args))

pd_cmd = "python3 -m paddle.distributed.launch --devices {} {} {}".format(
device_ids, infer_script_path, ' '.join(args))
#pd_cmd = "python3 -m paddle.distributed.launch --devices {} {} {}".format(
# device_ids, infer_script_path, ' '.join(args))
pd_cmd = "python3 {} {}".format(infer_script_path, ' '.join(args))
logger.info("Launch model with command: {}".format(pd_cmd))
logger.info("Model is initializing...")
infer_logger = open('modelmatrix/log/infer.log', 'a')
p = subprocess.Popen(
pd_cmd,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
stdout=infer_logger,
stderr=infer_logger,
preexec_fn=os.setsid)
return p
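
One note on the launch change above: with stdout=subprocess.PIPE, the child's output is only drained if the parent reads the pipe, and a full pipe buffer can stall the inference process; appending to a log file avoids that. A minimal sketch of the pattern, with an illustrative command string (the modelmatrix/log/infer.log path is taken from the diff and must exist beforehand):

import os
import subprocess

# Append the subprocess's stdout/stderr to a log file so the child never
# blocks on an unread pipe buffer.
infer_logger = open('modelmatrix/log/infer.log', 'a')
p = subprocess.Popen(
    "python3 infer.py --model_dir ./model",  # illustrative command, not from the repo
    shell=True,
    stdout=infer_logger,
    stderr=infer_logger,
    preexec_fn=os.setsid)  # new session, so the whole process group can be signalled later
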
18 changes: 18 additions & 0 deletions llm/fastdeploy_llm/utils/logging_util.py
@@ -16,6 +16,7 @@
import logging
import threading
import time
from enum import Enum
from typing import (Any, Generator, Optional, Union)
from logging.handlers import TimedRotatingFileHandler

@@ -42,6 +43,23 @@
}




error_format = """Error: Type {} Code {} Describe: {}"""

class ErrorCode(Enum):
C0000 = 0 # malformed query sent by the client
C0001 = 1 # client query failed validity check
S0000 = 2 # server overloaded
S0001 = 3 # server failed to start properly
S0002 = 4 # server shutting down

class ErrorType(Enum):
Query = 0 # query (client-side) error
Server = 1 # server-side error



class Logger(object):
_DEFAULT_NAME: str = 'FastDeploy'

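
For reference, a minimal sketch of how the error_format string and the two enums above are combined when reporting an error; this mirrors the usage added to triton_model.py in this PR (the conflicting task id is illustrative):

from fastdeploy_llm.utils.logging_util import logger, error_format, ErrorCode, ErrorType

# Build a structured message: error type name, error code name, then a description.
error_type = ErrorType.Query
error_code = ErrorCode.C0001
error_info = "Task id conflict with {}.".format("task-123")  # illustrative task id
error_msg = error_format.format(error_type.name, error_code.name, error_info)
logger.error(error_msg)
# Logs: Error: Type Query Code C0001 Describe: Task id conflict with task-123.
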