diff --git a/llm/fastdeploy_llm/engine.py b/llm/fastdeploy_llm/engine.py
index 8661893668..7208fed7de 100644
--- a/llm/fastdeploy_llm/engine.py
+++ b/llm/fastdeploy_llm/engine.py
@@ -376,40 +376,64 @@ def dy_input_preprocess(inputs):
     """
     stop_flags = inputs["dyinput_flags"]
     dec_length = inputs["seq_len_decoder"]
+    enc_length = inputs["seq_len_encoder"]
     bsz = len(stop_flags)
     tgt_pos = paddle.ones(shape=(bsz, 2, 1), dtype="int64")
     for i in range(bsz):
         if stop_flags[i] == 1:
-            length = int(dec_length[i, 0])
+            length = int(enc_length[i, 0])
             if args.is_ptuning:
                 model_id = inputs['model_id'][i]
-                if not model_id:
-                    attention_mask[i, 0, :length, :
-                                   max_prefix_len] = paddle.zeros(
-                                       [1, length, max_prefix_len],
-                                       dtype=args.dtype)
+                if "chatglm" in args.architecture:
+                    attention_mask[i, 0, :length, : length+max_prefix_len] = 1
+                    attention_mask[i, 0, :length - 1, length+max_prefix_len - 1] = 0
+                    tgt_pos[i, 0, 0] = paddle.to_tensor(
+                        [length], dtype="int64")
+
+                    if not model_id:
+                        tgt_generation_mask[i, 0, 0, max_prefix_len : length + max_prefix_len] = paddle.ones(
+                            shape=[1, length], dtype=args.dtype
+                        )
+                    else:
+                        tgt_generation_mask[i, 0, 0, : length + max_prefix_len] = paddle.ones(
+                            shape=[1, length + max_prefix_len], dtype=args.dtype
+                        )
                 else:
-                    attention_mask[i, 0, :length, :
-                                   max_prefix_len] = paddle.ones(
-                                       [1, length, max_prefix_len],
-                                       dtype=args.dtype)
-                attention_mask[i, 0, :length, max_prefix_len:max_prefix_len +
-                               length] = paddle.tril(
-                                   paddle.ones(
-                                       shape=[length, length],
-                                       dtype=args.dtype))
-                position_ids[i, :max_prefix_len] = 0
-                position_ids[i, max_prefix_len:max_prefix_len + inputs[
-                    "input_ids"].shape[1]] = paddle.arange(inputs["input_ids"]
-                                                           .shape[1])
-                tgt_generation_mask[i, 0, 0, :max_prefix_len +
-                                    length] = paddle.ones(
-                                        shape=[1, max_prefix_len + length],
+                    if not model_id:
+                        attention_mask[i, 0, :length, :
+                                       max_prefix_len] = paddle.zeros(
+                                           [1, length, max_prefix_len],
+                                           dtype=args.dtype)
+                        tgt_generation_mask[i, 0, 0, max_prefix_len : length + max_prefix_len] = paddle.ones(
+                            shape=[1, length], dtype=args.dtype
+                        )
+                    else:
+                        attention_mask[i, 0, :length, :
+                                       max_prefix_len] = paddle.ones(
+                                           [1, length, max_prefix_len],
+                                           dtype=args.dtype)
+                        tgt_generation_mask[i, 0, 0, :max_prefix_len +
+                                            length] = paddle.ones(
+                                                shape=[1, max_prefix_len + length],
+                                                dtype=args.dtype)
+
+                    attention_mask[i, 0, :length, max_prefix_len:max_prefix_len +
+                                   length] = paddle.tril(
+                                       paddle.ones(
+                                           shape=[length, length],
+                                           dtype=args.dtype))
+                    position_ids[i, :max_prefix_len] = 0
+                    position_ids[i, max_prefix_len:max_prefix_len + inputs[
+                        "input_ids"].shape[1]] = paddle.arange(inputs["input_ids"]
+                                                               .shape[1])
+
             else:
                 if "chatglm" in args.architecture:
+                    # todo: check alignment with paddlenlp:
+                    # attention_mask[i, 0, :length, :length] = 1
+                    # attention_mask[i, 0, :length - 1, length - 1] = 0
                     attention_mask[i, 0, :length, :length] = 0
                     attention_mask[i, 0, :length - 1, length - 1] = 1
                     tgt_pos[i, 0, 0] = paddle.to_tensor(
@@ -434,7 +458,7 @@ def dy_input_preprocess(inputs):
         inputs["tgt_generation_mask"] = tgt_generation_mask
         if "chatglm" in args.architecture:
             inputs["tgt_pos"] = tgt_pos
-            inputs["position_ids"] = generate_position_ids_for_chatglm(dec_length)
+            inputs["position_ids"] = generate_position_ids_for_chatglm(enc_length)
     if args.is_ptuning:
         prefix_caches = []
         for model_id in inputs['model_id']: