diff --git a/llm/fastdeploy_llm/engine.py b/llm/fastdeploy_llm/engine.py
index 8a96f8f7de..027e8dbdad 100644
--- a/llm/fastdeploy_llm/engine.py
+++ b/llm/fastdeploy_llm/engine.py
@@ -372,7 +372,7 @@ def get_alibi_slopes(num_heads):
         alibi_decoder + (1 - inputs["tgt_generation_mask"]) *
         paddle.finfo(inputs["tgt_generation_mask"].dtype).min)
     attention_mask = inputs["attention_mask"]
-    tgt_generation_mask = inputs["tgt_generation_mask"]
+    tgt_generation_mask = inputs["tgt_generation_mask"]
 
 
 def dy_input_preprocess(inputs):
@@ -394,70 +394,70 @@ def dy_input_preprocess(inputs):
 
         if args.is_ptuning:
             model_id = inputs['model_id'][i]
             if "chatglm" in args.architecture:
-                attention_mask[i, 0, :length, : length+max_prefix_len] = 1
-                attention_mask[i, 0, :length - 1, length+max_prefix_len - 1] = 0
+                attention_mask[i, 0, :length, :length + max_prefix_len] = 1
+                attention_mask[i, 0, :length - 1, length + max_prefix_len -
+                               1] = 0
                 tgt_pos[i, 0, 0] = paddle.to_tensor(
                     [length], dtype="int64")
-
+
                 if not model_id:
-                    tgt_generation_mask[i, 0, 0, max_prefix_len : length + max_prefix_len] = paddle.ones(
-                        shape=[1, length], dtype=args.dtype
-                    )
+                    tgt_generation_mask[i, 0, 0, max_prefix_len:length +
+                                        max_prefix_len] = paddle.ones(
+                                            shape=[1, length],
+                                            dtype=args.dtype)
                 else:
-                    tgt_generation_mask[i, 0, 0, : length + max_prefix_len] = paddle.ones(
-                        shape=[1, length + max_prefix_len], dtype=args.dtype
-                    )
+                    tgt_generation_mask[
+                        i, 0, 0, :length + max_prefix_len] = paddle.ones(
+                            shape=[1, length + max_prefix_len],
+                            dtype=args.dtype)
             else:
                 if "bloom" in args.architecture:
-                    attention_mask[i, :, :length, :length] = paddle.tril(
+                    attention_mask[i, :, :length, :length] = paddle.tril(
                         paddle.ones(
-                            shape=[length, length], dtype=args.dtype))
+                            shape=[length, length], dtype=args.dtype))
                 if not model_id:
                     attention_mask[i, :, :length, :
-                                   max_prefix_len] = paddle.zeros(
-                                       [1, length, max_prefix_len],
-                                       dtype=args.dtype)
-                    tgt_generation_mask[i, 0, 0, max_prefix_len : length + max_prefix_len] = paddle.ones(
-                        shape=[1, length], dtype=args.dtype
-                    )
+                                   max_prefix_len] = paddle.zeros(
+                                       [1, length, max_prefix_len],
+                                       dtype=args.dtype)
+                    tgt_generation_mask[i, 0, 0, max_prefix_len:length +
+                                        max_prefix_len] = paddle.ones(
+                                            shape=[1, length],
+                                            dtype=args.dtype)
                 else:
                     attention_mask[i, :, :length, :
-                                   max_prefix_len] = paddle.ones(
-                                       [1, length, max_prefix_len],
-                                       dtype=args.dtype)
-                    tgt_generation_mask[i, 0, 0, :max_prefix_len +
-                                        length] = paddle.ones(
-                                            shape=[1, max_prefix_len + length],
-                                            dtype=args.dtype)
-
-                    attention_mask[i, :, :length, max_prefix_len:max_prefix_len +
-                                   length] = paddle.tril(
-                                       paddle.ones(
-                                           shape=[length, length],
-                                           dtype=args.dtype))
+                                   max_prefix_len] = paddle.ones(
+                                       [1, length, max_prefix_len],
+                                       dtype=args.dtype)
+                    tgt_generation_mask[
+                        i, 0, 0, :max_prefix_len + length] = paddle.ones(
+                            shape=[1, max_prefix_len + length],
+                            dtype=args.dtype)
+
+                    attention_mask[i, :, :length, max_prefix_len:max_prefix_len +
+                                   length] = paddle.tril(
+                                       paddle.ones(
+                                           shape=[length, length],
+                                           dtype=args.dtype))
                 position_ids[i, :max_prefix_len] = 0
                 position_ids[i, max_prefix_len:max_prefix_len + inputs[
-                    "input_ids"].shape[1]] = paddle.arange(inputs["input_ids"]
-                                                           .shape[1])
+                    "input_ids"].shape[1]] = paddle.arange(inputs[
+                        "input_ids"].shape[1])
                 if "bloom" in args.architecture:
-                    tgt_generation_mask[i, :, 0, :max_prefix_len +
-                                        length] = paddle.ones(
-                                            shape=[1, max_prefix_len + length],
-                                            dtype=args.dtype)
-                    arange_tensor_encoder[i, :, :length + max_prefix_len] = paddle.arange(
-                        length + max_prefix_len).astype(args.dtype)
-
-
+                    tgt_generation_mask[
+                        i, :, 0, :max_prefix_len + length] = paddle.ones(
+                            shape=[1, max_prefix_len + length],
+                            dtype=args.dtype)
+                    arange_tensor_encoder[
+                        i, :, :length + max_prefix_len] = paddle.arange(
+                            length + max_prefix_len).astype(args.dtype)
         else:
             if "chatglm" in args.architecture:
-                attention_mask[i, 0, :length, :length] = 1
+                attention_mask[i, 0, :length, :length] = 1
                 attention_mask[i, 0, :length - 1, length - 1] = 0
-                tgt_pos[i, 0, 0] = paddle.to_tensor(
-                    [length], dtype="int64")
                 tgt_generation_mask[i, 0, 0, :length] = paddle.ones(
                     shape=[1, length], dtype=args.dtype)
             else:
-                position_ids[i, :length] = paddle.arange(length)
                 attention_mask[i, 0, :length, :length] = paddle.tril(
                     paddle.ones(
@@ -475,7 +475,6 @@ def dy_input_preprocess(inputs):
     inputs["position_ids"] = position_ids
     inputs["tgt_generation_mask"] = tgt_generation_mask
     if "chatglm" in args.architecture:
-        inputs["tgt_pos"] = tgt_pos
         inputs["position_ids"] = generate_position_ids_for_chatglm(enc_length)
     if args.is_ptuning:
         prefix_caches = []
diff --git a/llm/fastdeploy_llm/model.py b/llm/fastdeploy_llm/model.py
index d80721e964..bad7d79820 100644
--- a/llm/fastdeploy_llm/model.py
+++ b/llm/fastdeploy_llm/model.py
@@ -298,9 +298,9 @@ def _update_task_results(self, tasks):
         for i in range(len(tasks)):
             tasks[i].decode_status["finished_id"] = int(info["finished_ids"][
                 i])
-            # TODO JiangJiajun
             if self.config.is_arch("chatglm"):
-                tasks[i].decode_status["tgt_pos"] = int(info["tgt_pos"][i][0])
+                tasks[i].decode_status["tgt_pos"] = info["tgt_pos"][i].flatten(
+                ).tolist()
             else:
                 tasks[i].decode_status["tgt_pos"] = int(info["tgt_pos"][i])
 
@@ -398,9 +398,14 @@ def _add_dynamic_batching_inputs(self, inputs, tasks, stop_nums):
             inputs["model_id"].append(tasks[i].model_id)
             length = inputs["num_input_tokens"][i]
             if tasks[i].status == TaskStatus.DECODING:
-                tgt_pos.append(tasks[i].decode_status["tgt_pos"])
                 if self.config.is_arch("chatglm"):
-                    tgt_pos.append(1)
+                    tgt_pos += [
+                        tasks[i].decode_status["seq_lens_decoder"] -
+                        tasks[i].decode_status["step_idx"] + 1,
+                        tasks[i].decode_status["step_idx"]
+                    ]
+                else:
+                    tgt_pos.append(tasks[i].decode_status["tgt_pos"])
                 sequence_lengths_encoder[i] = 0
                 sequence_lengths_decoder[i] = tasks[i].decode_status[
                     "seq_lens_decoder"]
diff --git a/llm/fastdeploy_llm/serving/serving_model.py b/llm/fastdeploy_llm/serving/serving_model.py
index 573ada1cd2..b12ab6e912 100644
--- a/llm/fastdeploy_llm/serving/serving_model.py
+++ b/llm/fastdeploy_llm/serving/serving_model.py
@@ -23,12 +23,6 @@ class ServingModel:
 
     def __init__(self, config):
         self.config = config
-        if self.config.is_arch("chatglm") or self.config.is_arch("bloom"):
-            logger.warning(
-                "Dynamic batching will be disabled for model ChatGLM/BLOOM now!"
-            )
-            self.config.disable_dynamic_batching = 1
-
         logger.info("=============== Debug Information ===============")
         for k, v in self.config.__dict__.items():
             logger.info("{:<20}:{:<6}{}".format(k, "", v))
diff --git a/llm/fastdeploy_llm/serving/triton_model.py b/llm/fastdeploy_llm/serving/triton_model.py
index 9d6fc8b834..8c4914ff49 100644
--- a/llm/fastdeploy_llm/serving/triton_model.py
+++ b/llm/fastdeploy_llm/serving/triton_model.py
@@ -122,11 +122,12 @@ def execute(self, requests):
                 continue
 
             # 2. validate the deserializing process
+            task = Task()
             try:
                 task.from_dict(data)
             except Exception as e:
                 error_res = pb_utils.InferenceResponse(error=pb_utils.TritonError(
-                    "There's error while deserializing data from reqeust, error={}".
+                    "There's error while deserializing data from request, error={}".
                     format(e)))
                 res_sender = request.get_response_sender()
                 res_sender.send(
@@ -135,7 +136,6 @@ def execute(self, requests):
                 continue
 
             # 3. check if exists task id conflict
-            task = Task()
             if task.task_id is None:
                 task.task_id = str(uuid.uuid4())
             if task.task_id in self.response_handler: