From 8c588c72c76937379760d464e398253221e4af59 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 19 Sep 2024 11:21:27 +0800
Subject: [PATCH] refine llm infer

---
 llm/benchmark/analyse.py                      | 279 +++++++++++++++++
 llm/benchmark/benchmark.py                    | 280 ++++++++++++++++++
 llm/benchmark/logger.py                       | 160 ++++++++++
 .../Dockerfile_serving_cuda118_cudnn8         |   3 +-
 .../Dockerfile_serving_cuda123_cudnn9         |   3 +-
 llm/server/requirements.txt                   |   1 -
 llm/server/scripts/start_server.sh            |  15 +-
 llm/server/server/engine/infer.py             |   3 +-
 llm/server/server/triton_server.py            |  17 --
 llm/server/server/utils.py                    |   7 +-
 .../trition_server_model/model/1/model.py     |   1 +
 .../model}/config.pbtxt                       |   0
 12 files changed, 742 insertions(+), 27 deletions(-)
 create mode 100644 llm/benchmark/analyse.py
 create mode 100644 llm/benchmark/benchmark.py
 create mode 100644 llm/benchmark/logger.py
 create mode 100644 llm/server/trition_server_model/model/1/model.py
 rename llm/server/{config => trition_server_model/model}/config.pbtxt (100%)

diff --git a/llm/benchmark/analyse.py b/llm/benchmark/analyse.py
new file mode 100644
index 0000000000..63634bd191
--- /dev/null
+++ b/llm/benchmark/analyse.py
@@ -0,0 +1,279 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+import sys
+from dataclasses import dataclass, field
+from datetime import datetime
+
+import numpy as np
+
+
+@dataclass
+class Resp:
+    is_valid: bool = False
+
+    req_id: str = None
+    max_dec_len: int = None
+    min_dec_len: int = None
+    max_send_idx: int = None
+
+    input_token_num: int = None
+    output_token_num: int = None
+    is_end: bool = None
+
+    send_req_time: float = None
+    first_token_end2end_time: float = None
+    all_token_end2end_time: float = None
+    first_token_infer_time: float = None
+    all_token_infer_time: float = None
+
+    http_received_cost_time: float = 0
+    infer_received_cost_time: float = 0
+    tokenizer_encode_cost_time: float = 0
+    tokenizer_decode_cost_time: float = 0
+    preprocess_cost_time: float = 0
+    pending_cost_time: float = 0
+    get_image_cost_time: float = 0
+    process_image_cost_time: float = 0
+
+    input_text: str = None
+    output_list: list = field(default_factory=list)
+
+    error_msg: str = ""
+    exception_msg: str = ""
+
+    def auto_set_valid(self):
+        self.is_valid = True
+        names = ["req_id", "max_dec_len", "min_dec_len", "max_send_idx", "is_end",
+                 "output_token_num", "send_req_time", "first_token_end2end_time",
+                 "all_token_end2end_time", "first_token_infer_time", "all_token_infer_time"]
+        for name in names:
+            if getattr(self, name) is None:
+                self.is_valid = False
+        if self.error_msg != "" or self.exception_msg != "":
+            self.is_valid = False
+
+    def is_error(self) -> bool:
+        return self.error_msg != ""
+
+    def is_exception(self) -> bool:
+        return self.exception_msg != ""
+
+
+def str_to_datetime(date_string):
+    if "." in date_string:
+        return datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S.%f")
+    else:
+        return datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S")
+
+def datetime_diff(datetime_start, datetime_end):
+    if isinstance(datetime_start, str):
+        datetime_start = str_to_datetime(datetime_start)
+    if isinstance(datetime_end, str):
+        datetime_end = str_to_datetime(datetime_end)
+    if datetime_end > datetime_start:
+        cost = datetime_end - datetime_start
+    else:
+        cost = datetime_start - datetime_end
+    return cost.total_seconds()
+
+def pp_print(name, input_list):
+    out_str = f"{name:<35}"
+    for item in input_list:
+        out_str += f"{item:<15}"
+    print(out_str)
+
+def pp_print_md(name, lst):
+    info = f"| {name:<35} |"
+    for i in lst:
+        info += f" {i:<15} |"
+
+    print(info)
+
+
+def collect_response(input_path):
+    result_dict = {}
+    start_time = None
+    end_time = None
+    log_step = 100000
+
+    print("\nstart read and collect response...")
+    with open(input_path, 'r', encoding='utf-8') as file:
+        for idx, line in enumerate(file):
+            try:
+                item = json.loads(line.rstrip('\n'))
+            except Exception as e:
+                print(f"error when parsing line. idx: {idx}, line: {line} error: {e}")
+                continue
+
+            item_type = item['type']
+            assert item_type in ["request", "response", "error", "exception"]
+
+            req_id = item['content']['req_id']
+            if req_id in result_dict:
+                resp = result_dict[req_id]
+                if resp.is_valid:
+                    print("error: the req_id is already in result_dict")
+                    continue
+            else:
+                resp = Resp(req_id=req_id)
+                result_dict[req_id] = resp
+
+            if item_type == "request":
+                resp.max_dec_len = item['content']["max_dec_len"]
+                resp.min_dec_len = item['content']["min_dec_len"]
+                resp.input_text = item['content']["text"]
+                resp.send_req_time = str_to_datetime(item["now_time"])
+            elif item_type == "response":
+                content = item['content']
+                if content["send_idx"] == 0:
+                    resp.input_token_num = content.get("input_ids_len", 0)
+                    if content.get("http_received_time"):
+                        resp.http_received_cost_time = datetime_diff(resp.send_req_time, content.get("http_received_time"))
+                    if content.get("preprocess_start_time"):
+                        resp.infer_received_cost_time = datetime_diff(resp.send_req_time, content.get("preprocess_start_time"))
+
+                    resp.first_token_infer_time = content["inference_time_cost"]
+                    if content.get("preprocess_start_time") and content.get("preprocess_end_time"):
+                        resp.preprocess_cost_time = datetime_diff(content.get("preprocess_start_time"),
+                                                                  content.get("preprocess_end_time"))
+                    if content.get("preprocess_end_time") and content.get("schedule_start_time"):
+                        resp.pending_cost_time = datetime_diff(content.get("preprocess_end_time"),
+                                                               content.get("schedule_start_time"))
+                    resp.get_image_cost_time = content.get("get_image_cost_time", 0)
+                    resp.process_image_cost_time = content.get("process_image_cost_time", 0)
+                    resp.tokenizer_encode_cost_time = content.get("tokenizer_encode_cost_time", 0)
+                    resp.first_token_end2end_time = datetime_diff(resp.send_req_time, item["now_time"])
+                if content["is_end"] == 1:
+                    resp.is_end = True
+                    resp.max_send_idx = content["send_idx"]
+                    resp.output_token_num = content["tokens_all_num"]
+                    resp.all_token_end2end_time = datetime_diff(resp.send_req_time, item["now_time"])
+                    resp.all_token_infer_time = content["inference_time_cost"]
+                    resp.auto_set_valid()
+                resp.output_list.append({'idx': int(content['send_idx']), 'token': content['token']})
+                resp.tokenizer_decode_cost_time += content.get("tokenizer_decode_cost_time", 0)
+            elif item_type == "error":
+                resp.error_msg += item['content']["error_msg"]
+            elif item_type == "exception":
+                resp.exception_msg += item['content']["exception_msg"]
+
+            now_time = str_to_datetime(item["now_time"])
+            if start_time is None:
+                start_time = resp.send_req_time
+            if end_time is None:
+                end_time = now_time
+            elif end_time < now_time:
+                end_time = now_time
+
+            if idx % log_step == 0:
+                print(f"read {idx+1} chunks", end=', ', flush=True)
+
+    result_list = result_dict.values()
+    cost_time = datetime_diff(start_time, end_time)
+    print(f"\nstart_time: {start_time}, end_time: {end_time}, "
+          f"cost_time: {cost_time}, result_list_num: {len(result_list)}")
+    return result_list, cost_time
+
+def save_output_text(result_list, input_path):
+    output_path = input_path.replace(".jsonl", "-out_msg.jsonl")
+    with open(output_path, "w", encoding='utf-8') as out_file:
+        for result in result_list:
+            if result.is_valid:
+                output_list = sorted(result.output_list, key=lambda d: d['idx'])
+                output_text = ""
+                for i in output_list:
+                    output_text += i['token']
+                dict_obj = {'req_id': result.req_id, 'input_text': result.input_text, 'output_text': output_text}
+                out_file.write(json.dumps(dict_obj, ensure_ascii=False) + "\n")
+    print(f"output saved in {output_path}")
+
+
+def stats_and_percentiles(lst, round_bit=3, multi=1):
+    lst = [item * multi for item in lst]
+    num = len(lst)
+    max_val = round(max(lst), round_bit)
+    min_val = round(min(lst), round_bit)
+    avg_val = round(sum(lst) / len(lst), round_bit)
+
+    pct_50, pct_80, pct_95, pct_99 = np.percentile(lst, [50, 80, 95, 99])
+    pct_50 = round(pct_50, round_bit)
+    pct_80 = round(pct_80, round_bit)
+    pct_95 = round(pct_95, round_bit)
+    pct_99 = round(pct_99, round_bit)
+
+    return {"num": num, "max": max_val, "min": min_val, "avg": avg_val,
+            "pct_50": pct_50, "pct_80": pct_80, "pct_95": pct_95, "pct_99": pct_99}
+
+def analyse_single_key(result_list, key_name, round_bit=2, multi=1):
+    key_list = []
+    for resp in result_list:
+        if not resp.is_valid:
+            continue
+        key_list.append(resp.__dict__[key_name])
+
+    return stats_and_percentiles(key_list, round_bit, multi)
+
+def analyse_response(result_list, cost_time):
+    print("\nstart analyse response...")
+    valid_resp_num = 0
+    error_num = 0
+    exception_num = 0
+    for resp in result_list:
+        if resp.is_valid:
+            valid_resp_num += 1
+        elif resp.is_error():
+            error_num += 1
+            print(f"error resp: {resp}")
+        elif resp.is_exception():
+            exception_num += 1
+            print(f"exception resp: {resp}")
+
+    print(f"total response num: {len(result_list)}, valid response num: {valid_resp_num}, "
+          f"error_num: {error_num}, exception_num: {exception_num}")
+    print(f"qps: {round(valid_resp_num / cost_time, 2)} \n")
+
+    info_list = [{'key': 'output_token_num', 'multi': 1, 'msg': 'output token num'},
+                 {'key': 'first_token_infer_time', 'multi': 1000, 'msg': 'first token infer time(ms)'},
+                 {'key': 'all_token_infer_time', 'multi': 1000, 'msg': 'all token infer time(ms)'},
+                 {'key': 'first_token_end2end_time', 'multi': 1000, 'msg': 'first token end2end time(ms)'},
+                 {'key': 'all_token_end2end_time', 'multi': 1000, 'msg': 'all token end2end time(ms)'},
+                 {'key': 'infer_received_cost_time', 'multi': 1000, 'msg': 'infer received request time(ms)'},
+                 {'key': 'http_received_cost_time', 'multi': 1000, 'msg': 'http received request time(ms)'},
+                 {'key': 'preprocess_cost_time', 'multi': 1000, 'msg': 'preprocess time(ms)'},
+                 {'key': 'pending_cost_time', 'multi': 1000, 'msg': 'pending time(ms)'},
+                 ]
+    print("| metric | num | max | min | avg | 50% | 80% | 95% | 99% |")
+    print("| ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |")
+    for info in info_list:
+        out = analyse_single_key(result_list, info['key'], multi=info['multi'])
+        print(f"| {info['msg']:<35} | {out['num']:<15} | {out['max']:<15} | {out['min']:<15} | {out['avg']:<15} "
+              f"| {out['pct_50']:<15} | {out['pct_80']:<15} | {out['pct_95']:<15} | {out['pct_99']:<15} |")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input_path", type=str, help="the jsonl result file generated by benchmark.py")
+    args = parser.parse_args()
+    return args
+
+if __name__ == '__main__':
+    args = parse_args()
+    print(f"input_path: {args.input_path}")
+
+    result_list, cost_time = collect_response(args.input_path)
+    analyse_response(result_list, cost_time)
+    save_output_text(result_list, args.input_path)
diff --git a/llm/benchmark/benchmark.py b/llm/benchmark/benchmark.py
new file mode 100644
index 0000000000..994c342f4d
--- /dev/null
+++ b/llm/benchmark/benchmark.py
@@ -0,0 +1,280 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import json
+import os
+import sys
+import queue
+import threading
+import time
+import uuid
+from dataclasses import asdict, dataclass
+from datetime import datetime
+from functools import partial
+
+import httpx
+import numpy as np
+
+import tritonclient.grpc as grpcclient
+from tritonclient.utils import *
+
+from logger import get_logger
+
+def _http_send_worker(args, req_dict, result_queue):
+    is_error_resp = False
+    headers = {'Content-Type': 'application/json'}
+    with httpx.stream("POST", args.url, headers=headers, timeout=args.timeout, json=req_dict) as r:
+        for chunk in r.iter_lines():
+            resp = json.loads(chunk)
+            if resp.get("error_msg") or resp.get("error_code"):
+                is_error_resp = True
+                content = {"error_msg": resp.get("error_msg"), "req_id": req_dict.get("req_id")}
+                result_queue.put({"type": "error", "now_time": str(datetime.now()), "content": content})
+            else:
+                result_queue.put({"type": "response", "now_time": str(datetime.now()), "content": resp})
+    return is_error_resp
+
+def _grpc_send_worker(args, req_dict, result_queue):
+    class OutputData:
+        def __init__(self):
+            self._completed_requests = queue.Queue()
+
+    def triton_callback(output_data, result, error):
+        if error:
+            output_data._completed_requests.put(error)
+        else:
+            output_data._completed_requests.put(result)
+
+    model_name = "model"
+    inputs = [grpcclient.InferInput("IN", [1], np_to_triton_dtype(np.object_))]
+    outputs = [grpcclient.InferRequestedOutput("OUT")]
+    output_data = OutputData()
+    is_error_resp = False
+
+    with grpcclient.InferenceServerClient(url=args.url, verbose=False) as triton_client:
+        triton_client.start_stream(callback=partial(triton_callback, output_data))
+
+        input_data = json.dumps([req_dict])
+        inputs[0].set_data_from_numpy(np.array([input_data], dtype=np.object_))
+
+        triton_client.async_stream_infer(model_name=model_name,
+                                         inputs=inputs,
+                                         request_id=str(uuid.uuid4()),
+                                         outputs=outputs)
+
+        while True:
+            output_item = output_data._completed_requests.get(timeout=args.timeout)
+            if type(output_item) == InferenceServerException:
+                is_error_resp = True
+                error_msg = f"Exception: status is {output_item.status()}, msg is {output_item.message()}"
+                content = {"error_msg": error_msg, "req_id": req_dict.get("req_id")}
+                result_queue.put({"type": "error", "now_time": str(datetime.now()), "content": content})
+            else:
+                result = json.loads(output_item.as_numpy("OUT")[0])
+                result = result[0] if isinstance(result, list) else result
+                result_queue.put({"type": "response", "now_time": str(datetime.now()), "content": result})
+                if result.get("is_end") == 1:
+                    break
+    return is_error_resp
+
+def send_worker(args, data_queue, result_queue, worker_idx, logger):
+    """
+    send requests and put response into result_queue
+    """
+    logger.info(f"[send_worker {worker_idx}] start...")
+
+    cur_idx = 0
+    exception_num = 0
+    exception_threshold = 10
+    error_resp_num = 0
+    log_step = 10
+
+    while not data_queue.empty():
+        # read data
+        try:
+            input_data = data_queue.get(timeout=3)
+            remaining_num = data_queue.qsize()
+            cur_idx += 1
+        except queue.Empty:
+            logger.info(f"[send_worker {worker_idx}] data queue is empty")
+            break
+        except Exception as e:
+            exception_num += 1
+            logger.error(f"[send_worker {worker_idx}][fd_error] fetch data error: {e}")
+            continue
+
+        result_queue.put({"type": "request", "now_time": str(datetime.now()), "content": input_data})
+
+        # send request
+        try:
+            if args.api_type == 'http':
+                is_error_resp = _http_send_worker(args, input_data, result_queue)
+            elif args.api_type == 'grpc':
+                is_error_resp = _grpc_send_worker(args, input_data, result_queue)
+            error_resp_num += 1 if is_error_resp else 0
+        except Exception as e:
+            exception_num += 1
+            content = {"exception_msg": str(e), "req_id": input_data.get("req_id")}
+            result_queue.put({"type": "exception", "now_time": str(datetime.now()), "content": content})
+            if exception_num > exception_threshold:
+                logger.error(f"[send_worker {worker_idx}] exception num ({exception_num}) exceeds "
+                             f"threshold, exit")
+                break
+
+        # log
+        if cur_idx % log_step == 1:
+            logger.info(f"[send_worker {worker_idx}] processed_num: {cur_idx}, exception_num: {exception_num}, "
+                        f"error_resp_num: {error_resp_num}, data queue remaining ({remaining_num}) tasks")
+
+    logger.info(f"[send_worker {worker_idx}] exit, processed_num: {cur_idx}, exception_num: {exception_num}, "
+                f"error_resp_num: {error_resp_num}")
+
+def save_worker(result_path, result_queue, logger, timeout=50, log_step=10000):
+    """
+    save the result to file
+    """
+    logger.info("[save_worker] start...")
+    num = 0
+    with open(result_path, "w", encoding='utf-8') as out_file:
+        while True:
+            try:
+                res_chunk = result_queue.get(timeout=timeout)
+            except queue.Empty:
+                logger.info("[save_worker] result queue is empty")
+                break
+            except Exception as e:
+                logger.error(f"[save_worker] Error retrieving data from queue: {e}")
+                break
+
+            json_str = json.dumps(res_chunk, ensure_ascii=False)
+            out_file.write(json_str + "\n")
+            num += 1
+            if num % log_step == 0:
+                logger.info(f"[save_worker] processed {num} response chunks")
+
+    logger.info("[save_worker] exit")
+
+def prepare_data(data_path, data_num, benchmark=True, stream=True, timeout=180):
+    """
+    prepare data
+    """
+    '''
+    data_queue = queue.Queue()
+    with open(data_path, 'r', encoding='utf-8') as file:
+        for idx, line in enumerate(file):
+            raw_data = json.loads(line.rstrip('\n'))
+            input_data = {
+                "text": raw_data['text_before_process'],
+                "max_dec_len": raw_data["max_dec_len"],
+                "min_dec_len": raw_data["min_dec_len"],
+                "topp": raw_data["topp"],
+                "temperature": raw_data["temperature"],
+                "frequency_score": raw_data["frequency_score"],
+                "penalty_score": raw_data["penalty_score"],
+                "presence_score": raw_data["presence_score"],
+                "req_id": str(uuid.uuid4()),
+                "stream": stream,
+                "benchmark": benchmark,
+                "timeout": timeout,
+            }
+            if raw_data["history_QA"] != []:
+                input_data["history_qa"] = raw_data["history_QA"]
+
+            data_queue.put(input_data)
+            if data_num > 0 and idx + 1 >= data_num:
+                break
+    return data_queue
+    '''
+    data_queue = queue.Queue()
+    with open(data_path, 'r', encoding='utf-8') as file:
+        dataset = json.load(file)
+        dataset = [data for data in dataset if len(data['conversations']) >= 2]
+        # Only keep the first two turns of each conversation.
+        dataset = [(data['conversations'][0]['value'],
+                    data['conversations'][1]['value']) for data in dataset]
+        prompts = [prompt for prompt, _ in dataset]
+
+    for idx, text in enumerate(prompts):
+        input_data = {
+            "text": text,
+            "max_dec_len": 1024,
+            "min_dec_len": 1,
+            "topp": 0,
+            "temperature": 1,
+            "req_id": str(uuid.uuid4()),
+            "stream": stream,
+            "benchmark": benchmark,
+            "timeout": timeout,
+        }
+        data_queue.put(input_data)
+        if data_num > 0 and idx + 1 >= data_num:
+            break
+    return data_queue
+
+
+def parse_args():
+    """
+    parse the arguments
+    """
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--api_type", default="http", type=str, help="grpc or http api")
+    parser.add_argument("--url", default="http://0.0.0.0:8894/v1/chat/completions", type=str, help="the url for model server")
+    parser.add_argument("--data_path", default="data.jsonl", type=str, help="the path of the ShareGPT-style json data file")
+    parser.add_argument("--data_num", default=-1, type=int, help="-1 means all data")
+    parser.add_argument("--timeout", default=180, type=int, help="timeout for waiting response")
+    parser.add_argument("--worker_num", default=1, type=int, help="the number of workers for sending requests")
+    parser.add_argument("--tag", default="test", type=str, help="identify the test case")
+    args = parser.parse_args()
+    return args
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    # prepare
+    data_queue = prepare_data(args.data_path, args.data_num, benchmark=False, timeout=args.timeout)
+    if args.data_num < 0:
+        args.data_num = data_queue.qsize()
+    print(f"data_queue size: {data_queue.qsize()}")
+
+    test_tag = f"{args.tag}-{args.api_type}-wk{args.worker_num}-dn{args.data_num}"
+    logger = get_logger('benchmark', f'{test_tag}-log')
+    logger.info(f"args: {args}")
+    logger.info(f"test_tag: {test_tag}")
+
+    result_path = f"output/{test_tag}.jsonl"
+    if os.path.exists(result_path):
+        logger.error(f"result file ({result_path}) already exists, exit")
+        exit()
+    if not os.path.exists("output/"):
+        os.makedirs("output/")
+    logger.info(f"result_path: {result_path}")
+
+    # save worker
+    worker_list = []
+    result_queue = queue.Queue()
+    worker = threading.Thread(target=save_worker, args=(result_path, result_queue, logger, 20))
+    worker.start()
+    worker_list.append(worker)
+
+    # send worker
+    tic = time.time()
+    for idx in range(args.worker_num):
+        worker = threading.Thread(target=send_worker, args=(args, data_queue, result_queue, idx, logger))
+        worker.start()
+        worker_list.append(worker)
+    for worker in worker_list:
+        worker.join()
+
+    toc = time.time()
+    logger.info(f'Done, cost time: {round(toc - tic, 2)}s')
diff --git a/llm/benchmark/logger.py b/llm/benchmark/logger.py
new file mode 100644
index 0000000000..3db94344d8
--- /dev/null
+++ b/llm/benchmark/logger.py
@@ -0,0 +1,160 @@
+
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import codecs
+import logging
+import os
+import pickle
+import re
+import subprocess
+import time
+from datetime import datetime
+from enum import Enum
+from logging.handlers import BaseRotatingHandler
+from pathlib import Path
+
+
+class DailyRotatingFileHandler(BaseRotatingHandler):
+    """
+    like `logging.TimedRotatingFileHandler`, but this class supports multi-process rotation
+    """
+
+    def __init__(
+        self,
+        filename,
+        backupCount=0,
+        encoding="utf-8",
+        delay=False,
+        utc=False,
+        **kwargs
+    ):
+        self.backup_count = backupCount
+        self.utc = utc
+        self.suffix = "%Y-%m-%d"
+        self.base_log_path = Path(filename)
+        self.base_filename = self.base_log_path.name
+        self.current_filename = self._compute_fn()
+        self.current_log_path = self.base_log_path.with_name(self.current_filename)
+        BaseRotatingHandler.__init__(self, filename, "a", encoding, delay)
+
+    def shouldRollover(self, record):
+        """
+        check whether the log should roll over
+        """
+        if self.current_filename != self._compute_fn():
+            return True
+        return False
+
+    def doRollover(self):
+        """
+        roll over the log
+        """
+        if self.stream:
+            self.stream.close()
+            self.stream = None
+
+        self.current_filename = self._compute_fn()
+        self.current_log_path = self.base_log_path.with_name(self.current_filename)
+
+        if not self.delay:
+            self.stream = self._open()
+
+        self.delete_expired_files()
+
+    def _compute_fn(self):
+        """
+        calculate the log file name corresponding to the current time
+        """
+        return self.base_filename + "." + time.strftime(self.suffix, time.localtime())
+
+    def _open(self):
+        """
+        open a new log file
+        """
+        if self.encoding is None:
+            stream = open(str(self.current_log_path), self.mode)
+        else:
+            stream = codecs.open(str(self.current_log_path), self.mode, self.encoding)
+
+        if self.base_log_path.exists():
+            try:
+                if (
+                    not self.base_log_path.is_symlink()
+                    or os.readlink(self.base_log_path) != self.current_filename
+                ):
+                    os.remove(self.base_log_path)
+            except OSError:
+                pass
+
+        try:
+            os.symlink(self.current_filename, str(self.base_log_path))
+        except OSError:
+            pass
+        return stream
+
+    def delete_expired_files(self):
+        """
+        delete expired log files
+        """
+        if self.backup_count <= 0:
+            return
+
+        file_names = os.listdir(str(self.base_log_path.parent))
+        result = []
+        prefix = self.base_filename + "."
+        plen = len(prefix)
+        for file_name in file_names:
+            if file_name[:plen] == prefix:
+                suffix = file_name[plen:]
+                if re.match(r"^\d{4}-\d{2}-\d{2}(\.\w+)?$", suffix):
+                    result.append(file_name)
+        if len(result) < self.backup_count:
+            result = []
+        else:
+            result.sort()
+            result = result[: len(result) - self.backup_count]
+
+        for file_name in result:
+            os.remove(str(self.base_log_path.with_name(file_name)))
+
+
+def get_logger(name, file_name=None):
+    """
+    get logger
+    """
+    if file_name is None:
+        file_name = name + ".log"
+    log_dir = os.getenv("log_dir", default="log")
+    if not os.path.exists(log_dir):
+        os.mkdir(log_dir)
+
+    logger = logging.getLogger(name)
+    is_debug = int(os.getenv("FD_DEBUG", default=0))
+    if is_debug:
+        logger.setLevel(level=logging.DEBUG)
+    else:
+        logger.setLevel(level=logging.INFO)
+
+    log_file = "{0}/{1}".format(log_dir, file_name)
+    handler = DailyRotatingFileHandler(log_file, backupCount=7)
+
+    formatter = logging.Formatter(
+        "%(levelname)-8s %(asctime)s %(process)-5s %(filename)s[line:%(lineno)d] %(message)s"
+    )
+    handler.setFormatter(formatter)
+    logger.addHandler(handler)
+    logger.propagate = False
+    return logger
diff --git a/llm/dockerfiles/Dockerfile_serving_cuda118_cudnn8 b/llm/dockerfiles/Dockerfile_serving_cuda118_cudnn8
index 3f5a2a511f..1fe513464f 100644
--- a/llm/dockerfiles/Dockerfile_serving_cuda118_cudnn8
+++ b/llm/dockerfiles/Dockerfile_serving_cuda118_cudnn8
@@ -30,7 +30,6 @@ RUN cd /opt/output/Serving/ \
     && cp scripts/start_server.sh . && cp scripts/stop_server.sh . \
     && rm -rf scripts
 
-RUN python3 -m pip install protobuf==3.20.0
-
 ENV http_proxy ""
 ENV https_proxy ""
+ENV TZ=Asia/Shanghai
diff --git a/llm/dockerfiles/Dockerfile_serving_cuda123_cudnn9 b/llm/dockerfiles/Dockerfile_serving_cuda123_cudnn9
index 786bf2aecd..e41cf9770f 100644
--- a/llm/dockerfiles/Dockerfile_serving_cuda123_cudnn9
+++ b/llm/dockerfiles/Dockerfile_serving_cuda123_cudnn9
@@ -30,7 +30,6 @@ RUN cd /opt/output/Serving/ \
     && cp scripts/start_server.sh . && cp scripts/stop_server.sh . \
    && rm -rf scripts
 
-RUN python3 -m pip install protobuf==3.20.0
-
 ENV http_proxy ""
 ENV https_proxy ""
+ENV TZ=Asia/Shanghai
diff --git a/llm/server/requirements.txt b/llm/server/requirements.txt
index cc65a67266..1056c238c8 100644
--- a/llm/server/requirements.txt
+++ b/llm/server/requirements.txt
@@ -1,5 +1,4 @@
 # model server
-paddlenlp==2.7.2
 sentencepiece
 pycryptodome
 tritonclient[all]==2.41.1
diff --git a/llm/server/scripts/start_server.sh b/llm/server/scripts/start_server.sh
index 43ef7cb3c7..5735310f00 100644
--- a/llm/server/scripts/start_server.sh
+++ b/llm/server/scripts/start_server.sh
@@ -40,6 +40,19 @@ export METRICS_PORT=${METRICS_PORT:-"8722"}
 export INFER_QUEUE_PORT=${INFER_QUEUE_PORT:-"8813"}
 export PUSH_MODE_HTTP_PORT=${PUSH_MODE_HTTP_PORT:-"9965"}
 
+ports=(${HTTP_PORT} ${GRPC_PORT} ${METRICS_PORT} ${INFER_QUEUE_PORT} ${PUSH_MODE_HTTP_PORT})
+for port in "${ports[@]}"; do
+    output=$(netstat -tuln | grep ":${port} ")
+    if [ -n "$output" ]; then
+        echo "${port} is already in use"
+        exit 1
+    fi
+done
+
+script_dir=$(realpath "$(dirname "${BASH_SOURCE[0]}")")
+root_dir=$(dirname "$script_dir")
+export PYTHONPATH=${root_dir}:${PYTHONPATH}
+
 mkdir -p log
 rm -rf console.log log/*
 rm -rf /dev/shm/*
@@ -49,7 +62,7 @@ echo "start serving ..."
 tritonserver --exit-timeout-secs 100 --cuda-memory-pool-byte-size 0:0 --cuda-memory-pool-byte-size 1:0 \
     --cuda-memory-pool-byte-size 2:0 --cuda-memory-pool-byte-size 3:0 --cuda-memory-pool-byte-size 4:0 \
     --cuda-memory-pool-byte-size 5:0 --cuda-memory-pool-byte-size 6:0 --cuda-memory-pool-byte-size 7:0 \
-    --pinned-memory-pool-byte-size 0 --model-repository llm_model/ \
+    --pinned-memory-pool-byte-size 0 --model-repository trition_server_model \
     --allow-http false \
     --grpc-port=${GRPC_PORT} \
     --metrics-port=${METRICS_PORT} \
diff --git a/llm/server/server/engine/infer.py b/llm/server/server/engine/infer.py
index 5d1f9bd33b..f94fac6e56 100644
--- a/llm/server/server/engine/infer.py
+++ b/llm/server/server/engine/infer.py
@@ -27,10 +27,11 @@
 import paddle.distributed.fleet as fleet
 from paddlenlp.utils.llm_utils import get_rotary_position_embedding
 from paddlenlp_ops import step_paddle
+
 from server.data.processor import DataProcessor
 from server.engine.config import Config
 from server.utils import get_logger
-from task_queue_manager import TaskQueueManager
+from server.engine.task_queue_manager import TaskQueueManager
 
 File_Path = os.path.realpath(sys.argv[0])
 Dir_Path = os.path.dirname(File_Path)
diff --git a/llm/server/server/triton_server.py b/llm/server/server/triton_server.py
index 12024c251a..f8f43b182a 100644
--- a/llm/server/server/triton_server.py
+++ b/llm/server/server/triton_server.py
@@ -408,23 +408,6 @@ def _update_metrics(self):
         self.metrics["available_resource"].set(block_num * 1.0 / self.cfg.max_block_num)
 
-    def _get_current_server_info(self):
-        """
-        get server info
-        """
-        available_batch_size = min(self.cfg.max_prefill_batch,
-                                   self.engine.available_batch())
-        available_block_num = self.engine.available_block_num()
-        server_info = {
-            "block_size": int(self.cfg.block_size),
-            "block_num": int(available_block_num),
-            "dec_token_num": int(self.cfg.dec_token_num),
-            "available_resource":
-                1.0 * available_block_num / self.cfg.max_block_num,
-            "max_batch_size": int(available_batch_size),
-        }
-        return server_info
-
 
 
 def _send_result(result_dict, sender, end_flag=0):
     """
diff --git a/llm/server/server/utils.py b/llm/server/server/utils.py
index bb80f6b0a4..fc6e50cecf 100644
--- a/llm/server/server/utils.py
+++ b/llm/server/server/utils.py
@@ -135,6 +135,9 @@ def get_logger(name, file_name, without_formater=False):
     get logger
     """
     log_dir = os.getenv("FD_LOG_DIR", default="log")
+    if not os.path.exists(log_dir):
+        os.mkdir(log_dir)
+
     is_debug = int(os.getenv("FD_DEBUG", default=0))
     logger = logging.getLogger(name)
     if is_debug:
@@ -142,10 +145,8 @@
     else:
         logger.setLevel(level=logging.INFO)
 
-    LOG_FILE = "{0}/{1}".format(log_dir, file_name)
     backup_count = int(os.getenv("FD_LOG_BACKUP_COUNT", 7))
-    handler = DailyRotatingFileHandler(LOG_FILE, backupCount=backup_count)
-
+    handler = DailyRotatingFileHandler("{0}/{1}".format(log_dir, file_name), backupCount=backup_count)
     formatter = logging.Formatter(
         "%(levelname)-8s %(asctime)s %(process)-5s %(filename)s[line:%(lineno)d] %(message)s"
     )
diff --git a/llm/server/trition_server_model/model/1/model.py b/llm/server/trition_server_model/model/1/model.py
new file mode 100644
index 0000000000..eb15091e23
--- /dev/null
+++ b/llm/server/trition_server_model/model/1/model.py
@@ -0,0 +1 @@
+from server.triton_server import TritonPythonModel
diff --git a/llm/server/config/config.pbtxt b/llm/server/trition_server_model/model/config.pbtxt
similarity index 100%
rename from llm/server/config/config.pbtxt
rename to llm/server/trition_server_model/model/config.pbtxt