refine llm infer #2523

Closed
wants to merge 8 commits into from
Changes from 1 commit
up
juncaipeng committed Sep 23, 2024
commit 030d5b5f88260cbbfa781983b4e8ada23f51d1c7
2 changes: 1 addition & 1 deletion llm/benchmark/benchmark.py
@@ -248,7 +248,7 @@ def parse_args():
print(f"data_queue size: {data_queue.qsize()}")

test_tag = f"{args.tag}-{args.api_type}-wk{args.worker_num}-dn{args.data_num}"
logger = get_logger('benchmark', f'{test_tag}-log')
logger = get_logger('benchmark', f'{test_tag}.log')
logger.info(f"args: {args}")
logger.info(f"test_tag: {test_tag}")

14 changes: 10 additions & 4 deletions llm/server/server/engine/config.py
@@ -28,6 +28,9 @@ class Config:

def __init__(self):
self.read_from_env()
self.read_from_config()
self.postprocess()
self.check()

def read_from_env(self):
"""
@@ -130,10 +133,6 @@ def read_from_env(self):
)
self.generation_config = None

self.read_from_config()
self.postprocess()
self.check()

def postprocess(self):
"""
calculate some parameters
@@ -234,3 +233,10 @@ def get_unique_name(self, name):

def __str__(self) -> str:
return json.dumps(self.__dict__, indent=4)

cfg_inst = None
def get_global_config():
    global cfg_inst
    if cfg_inst is None:
        cfg_inst = Config()
    return cfg_inst
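
A brief usage sketch of the new accessor (module path taken from the imports later in this PR; the identity check is only illustrative): components that previously built their own Config() can now call get_global_config() and share a single parsed instance.

from server.engine.config import get_global_config

cfg_a = get_global_config()   # first call constructs Config() lazily
cfg_b = get_global_config()   # subsequent calls return the same object
assert cfg_a is cfg_b         # env vars and config files are read once per process

One caveat: the lazy check-then-assign in get_global_config is not lock-protected, so it assumes the first call is not raced by multiple threads; if two threads do race on it, two Config objects could briefly be built.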
125 changes: 21 additions & 104 deletions llm/server/server/engine/engine.py
@@ -25,90 +25,34 @@
import numpy as np
from server.engine.resource_manager import ResourceManager
from server.engine.task_queue_manager import (TaskQueueManager,
launch_queue_service)
from server.engine.token_processor import TokenProcessor, WarmUpTokenProcessor
launch_task_queue_manager)
from server.engine.token_processor import TokenProcessor
from server.utils import model_server_logger


class Engine(object):
"""
Engine Class
"""
def __init__(self, cfg, token_processor):
def __init__(self, cfg):
self.cfg = cfg
self.resource_manager = ResourceManager(self.cfg)
self.token_processor = token_processor
self.token_processor = TokenProcessor(self.cfg)
self.token_processor.set_resource_manager(self.resource_manager)
self.is_started = False

self._init_engine_flags()
self._finalizer = weakref.finalize(self, self._exit_sub_services)

def start(self):
"""
initialize engine and start sub services
"""
assert not self.is_started, "The engine is already started!"
start_time = time.time()
self.queue_service = self._start_tasks_queue_service()
self.tasks_queue = TaskQueueManager(mp_num=self.cfg.mp_num, port=self.cfg.infer_port)
self.tqm_proc = self._start_task_queue_manager()
self.task_queue_manager = TaskQueueManager(mp_num=self.cfg.mp_num, port=self.cfg.infer_port)

self.token_processor.tasks_queue = self.tasks_queue
self.infer_proc = self._start_infer_service()
start_time = time.time()
self.infer_proc = self._start_infer_process()
model_server_logger.info("Waitting infer processes ready...")
while not self._infer_processes_ready():
time.sleep(1)
self.is_started = True

# start warmup
if self.cfg.use_warmup:
model_server_logger.info("Start warmup")
self._set_warmup_token_processor()
self.warmup()
self._del_warmup_token_processor()
model_server_logger.info("Warmup finish")

# start TokenProcessor thread
self.token_processor.run()
model_server_logger.info("Infer processes are launched with {} seconds.".format(time.time() - start_time))

def warmup(self):
"""
construct test tasks and avoid out of memory problem in the infer process
"""
# get eos_token_id
from server.data.processor import DataProcessor
eos_token_ids = DataProcessor().get_eos_tokens()

# construct test tasks
res_task = []
for j in range(2 * self.cfg.max_batch_size):
data = {
"input_ids": [5],
"req_id": j,
"max_dec_len": self.cfg.dec_len_limit,
"min_dec_len": int(self.cfg.dec_len_limit * 0.5) + 1,
"eos_token_ids": eos_token_ids
}
res_task.append(data)
for j in range(2 * self.cfg.max_prefill_batch):
data = {
"input_ids": [5] * self.cfg.seq_len_limit,
"req_id": j + 2 * self.cfg.max_batch_size,
"max_dec_len": 1,
"min_dec_len": 1,
"eos_token_ids": eos_token_ids
}
res_task.append(data)

for x in res_task:
while self.available_batch() == 0 or not self.insert_tasks([x]):
time.sleep(0.0002)

self.token_processor._is_blocking = False
# wait for all tasks finished
while not self.all_tasks_finished():
time.sleep(1)
self._finalizer = weakref.finalize(self, self._exit_sub_services)

def insert_tasks(self, tasks):
"""
@@ -158,13 +102,9 @@ def insert_tasks(self, tasks):
if not tasks:
return False

self.token_processor.number_of_tasks += len(tasks)
for i in range(len(tasks)):
self.token_processor.number_of_input_tokens += len(tasks[i]["input_ids"])

req_ids = [t["req_id"] for t in tasks]
model_server_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
self.tasks_queue.put((tasks, self.resource_manager.real_bsz))
self.task_queue_manager.put((tasks, self.resource_manager.real_bsz))
return True

def task_is_finished(self, index):
@@ -187,7 +127,7 @@ def is_queue_empty(self):
Returns:
return: True if empty, False otherwise
"""
return self.tasks_queue.empty()
return self.task_queue_manager.empty()

def is_resource_sufficient(self, input_token_num):
"""
@@ -228,29 +168,6 @@ def available_block_num(self):
"""
return self.resource_manager.availabel_block_num()

def _set_warmup_token_processor(self):
"""
set token_processor for warmup
"""
self.token_processor_backup = self.token_processor
self.token_processor = WarmUpTokenProcessor(self.cfg)
self.token_processor.set_resource_manager(self.resource_manager)
self.token_processor.tasks_queue = self.tasks_queue

# start TokenProcessor thread
self.token_processor.run()

def _del_warmup_token_processor(self):
"""
delete token_processor for warmup
"""
self.token_processor.stop()
del self.token_processor

# reset token_processor
self.token_processor = self.token_processor_backup
del self.token_processor_backup

def _infer_processes_ready(self):
"""
check whether all infer processes are ready
@@ -341,20 +258,20 @@ def _exit_sub_services(self):
"""
exit sub services
"""
if hasattr(self, "queue_service") and self.queue_service is not None:
self.queue_service.terminate()
self.queue_service.join()
if hasattr(self, "tqm_proc") and self.tqm_proc is not None:
self.tqm_proc.terminate()
self.tqm_proc.join()
if hasattr(self, "infer_proc") and self.infer_proc is not None:
os.killpg(self.infer_proc.pid, signal.SIGTERM)

def _start_tasks_queue_service(self):
def _start_task_queue_manager(self):
"""
start the task queue manager process

Returns:
p: process handle
"""
p = multiprocessing.Process(target=launch_queue_service, args=(self.cfg.infer_port, self.cfg.mp_num))
p = multiprocessing.Process(target=launch_task_queue_manager, args=(self.cfg.infer_port, self.cfg.mp_num))
p.start()
time.sleep(0.3)
if p.is_alive():
@@ -366,9 +283,9 @@ def _start_gpu_infer_service(self):
raise Exception(error_msg)
return p

def _start_gpu_infer_service(self):
def _start_gpu_infer_process(self):
"""
start gpu infer service
start gpu infer process

Returns:
p: process handle
@@ -394,8 +311,8 @@ def _start_infer_service(self):
)
return p

def _start_infer_service(self):
def _start_infer_process(self):
"""
start infer service
start infer process
"""
return self._start_gpu_infer_service()
return self._start_gpu_infer_process()
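
After this refactor the Engine constructs its own TokenProcessor and ResourceManager and only needs the shared Config, so launching it reduces to a few lines. A hypothetical launch sketch (the entrypoint function is invented for illustration; module paths follow the imports shown in this diff):

from server.engine.config import get_global_config
from server.engine.engine import Engine

def launch_engine():
    cfg = get_global_config()   # shared Config singleton added in config.py
    engine = Engine(cfg)        # TokenProcessor and ResourceManager are built inside __init__
    engine.start()              # starts the task queue manager and GPU infer process, then waits until ready
    return engine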
2 changes: 1 addition & 1 deletion llm/server/server/engine/task_queue_manager.py
@@ -131,7 +131,7 @@ def get(self):
return input_list, read_finish


def launch_queue_service(port, num_workers):
def launch_task_queue_manager(port, num_workers):
"""
Start the task queue manager process for inter-process communication

70 changes: 6 additions & 64 deletions llm/server/server/engine/token_processor.py
@@ -13,6 +13,7 @@
# limitations under the License.

import os
import queue
import threading
import time
import traceback
@@ -32,20 +33,17 @@ def __init__(self, cfg):
import paddle
paddle.device.set_device("cpu")
self.cfg = cfg
self.out_queue = queue.Queue()
self.resource_manager = None
# record all tokens for each request
self.all_tokens = [[] for _ in range(self.cfg.max_batch_size)]

self.tokens_counter = Counter()
self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64")
self.worker = None

self.record_time_interval = int(os.getenv("RECORD_TIME_INTERVAL", "600"))
assert self.record_time_interval < 3600, "The RECORD_TIME_INTERVAL cannot exceed 3600."
self.statics_start_time = time.time()
self.number_of_tasks = 0
self.number_of_input_tokens = 0
self.number_of_output_tokens = 0
self.worker = threading.Thread(target=self.process_sampling_results, args=())
self.worker.daemon = True
self.worker.start()

def set_resource_manager(self, resource_manager):
"""
@@ -57,18 +55,6 @@ def set_resource_manager(self, resource_manager):
assert self.resource_manager is None, "The resource manager is not None, cannot set again."
self.resource_manager = resource_manager

def run(self):
"""
start thread to get tokens
"""
assert self.resource_manager is not None, "The resource manager is None, cannot run."
if self.worker is not None:
raise Exception("Worker is already running!")

self.worker = threading.Thread(target=self.process_sampling_results, args=())
self.worker.daemon = True
self.worker.start()

def process_sampling_results(self):
"""
read tokens from paddle inference engine and process
@@ -93,13 +79,7 @@ def postprocess(self, batch_result, exist_finished_task=False):
batch_result (list): batch results
exist_finished_task (bool): whether there is a finished task
"""
result_dir = "./generate_token_results"
if not os.path.exists(result_dir):
os.makedirs(result_dir)
for result in batch_result:
result_file = os.path.join(result_dir, result["req_id"])
with open(result_file, "a") as f:
f.write("{}\n".format(result))
self.out_queue.put(batch_result)

def _get_single_result(self, i, task_id, token_id, task):
"""
@@ -198,7 +178,6 @@ def _process_batch_output(self):
if token_id not in task["eos_token_ids"]:
self.all_tokens[i].append(token_id)

self.number_of_output_tokens += 1
if token_id in task["eos_token_ids"]:
self._recycle_resources(task_id, i, task)
model_server_logger.info("req_id: {0} finished".format(task_id))
@@ -207,40 +186,3 @@
batch_result.append(result)

self.postprocess(batch_result, exist_finished_task)


class WarmUpTokenProcessor(TokenProcessor):
"""
Warmup Processor
"""
def __init__(self, cfg):
super().__init__(cfg)
self._is_running = True
self._is_blocking = True

def postprocess(self, batch_result, exist_finished_task=False):
pass

def process_sampling_results(self):
"""
get output from model and process it
"""
while self._is_running:
try:
rank_id = 0
get_output(self.output_tokens, rank_id, self._is_blocking)

if self.output_tokens[0, 0] == -2:
continue
self._process_batch_output()
except Exception as e:
model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc())))

def stop(self):
"""
stop warm up thread
"""
self._is_running = False
self.worker.join()
model_server_logger.info("warm up thread stop")
del self.worker
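
With this change postprocess() pushes each batch of results onto an in-memory out_queue instead of appending them to files under ./generate_token_results, and the reader thread now starts directly in __init__. A hypothetical consumer sketch showing how a serving layer might drain that queue (drain_results is invented for illustration):

import queue

def drain_results(token_processor, timeout=1.0):
    # Pull result batches off TokenProcessor.out_queue until it stays empty for `timeout` seconds.
    results = []
    while True:
        try:
            batch = token_processor.out_queue.get(timeout=timeout)
        except queue.Empty:
            break
        results.extend(batch)   # each entry is one request's incremental result dict
    return results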
@@ -25,11 +25,11 @@
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response
from server.engine.config import Config
from server.engine.config import get_global_config
from server.utils import get_logger

app = FastAPI()
env_config = Config()
env_config = get_global_config()
logger = get_logger("health_checker", "health_checker.log")

