[LLM] Add prefix for chatglm (#2233)
* add prefix cache for chatglm

* support chatglm
rainyfly authored Oct 12, 2023
1 parent 80bb8ed commit 986b233
Showing 1 changed file with 47 additions and 23 deletions.
70 changes: 47 additions & 23 deletions llm/fastdeploy_llm/engine.py
@@ -376,40 +376,64 @@ def dy_input_preprocess(inputs):
     """
     stop_flags = inputs["dyinput_flags"]
     dec_length = inputs["seq_len_decoder"]
+    enc_length = inputs["seq_len_encoder"]
     bsz = len(stop_flags)

     tgt_pos = paddle.ones(shape=(bsz, 2, 1), dtype="int64")

     for i in range(bsz):
         if stop_flags[i] == 1:
-            length = int(dec_length[i, 0])
+            length = int(enc_length[i, 0])
             if args.is_ptuning:
                 model_id = inputs['model_id'][i]
-                if not model_id:
-                    attention_mask[i, 0, :length, :
-                                   max_prefix_len] = paddle.zeros(
-                                       [1, length, max_prefix_len],
-                                       dtype=args.dtype)
+                if "chatglm" in args.architecture:
+                    attention_mask[i, 0, :length, :length + max_prefix_len] = 1
+                    attention_mask[i, 0, :length - 1, length + max_prefix_len - 1] = 0
+                    tgt_pos[i, 0, 0] = paddle.to_tensor(
+                        [length], dtype="int64")
+
+                    if not model_id:
+                        tgt_generation_mask[i, 0, 0, max_prefix_len:length + max_prefix_len] = paddle.ones(
+                            shape=[1, length], dtype=args.dtype)
+                    else:
+                        tgt_generation_mask[i, 0, 0, :length + max_prefix_len] = paddle.ones(
+                            shape=[1, length + max_prefix_len], dtype=args.dtype)
                 else:
-                    attention_mask[i, 0, :length, :
-                                   max_prefix_len] = paddle.ones(
-                                       [1, length, max_prefix_len],
-                                       dtype=args.dtype)
-                attention_mask[i, 0, :length, max_prefix_len:max_prefix_len +
-                               length] = paddle.tril(
-                                   paddle.ones(
-                                       shape=[length, length],
-                                       dtype=args.dtype))
-                position_ids[i, :max_prefix_len] = 0
-                position_ids[i, max_prefix_len:max_prefix_len + inputs[
-                    "input_ids"].shape[1]] = paddle.arange(inputs["input_ids"]
-                                                           .shape[1])
-                tgt_generation_mask[i, 0, 0, :max_prefix_len +
-                                    length] = paddle.ones(
-                                        shape=[1, max_prefix_len + length],
-                                        dtype=args.dtype)
+                    if not model_id:
+                        attention_mask[i, 0, :length, :
+                                       max_prefix_len] = paddle.zeros(
+                                           [1, length, max_prefix_len],
+                                           dtype=args.dtype)
+                        tgt_generation_mask[i, 0, 0, max_prefix_len:length + max_prefix_len] = paddle.ones(
+                            shape=[1, length], dtype=args.dtype)
+                    else:
+                        attention_mask[i, 0, :length, :
+                                       max_prefix_len] = paddle.ones(
+                                           [1, length, max_prefix_len],
+                                           dtype=args.dtype)
+                        tgt_generation_mask[i, 0, 0, :max_prefix_len +
+                                            length] = paddle.ones(
+                                                shape=[1, max_prefix_len + length],
+                                                dtype=args.dtype)
+
+                    attention_mask[i, 0, :length, max_prefix_len:max_prefix_len +
+                                   length] = paddle.tril(
+                                       paddle.ones(
+                                           shape=[length, length],
+                                           dtype=args.dtype))
+                    position_ids[i, :max_prefix_len] = 0
+                    position_ids[i, max_prefix_len:max_prefix_len + inputs[
+                        "input_ids"].shape[1]] = paddle.arange(inputs["input_ids"]
+                                                               .shape[1])

             else:
                 if "chatglm" in args.architecture:
                     # todo: check alignment with paddlenlp:
                     # attention_mask[i, 0, :length, :length] = 1
                     # attention_mask[i, 0, :length - 1, length - 1] = 0
                     attention_mask[i, 0, :length, :length] = 0
                     attention_mask[i, 0, :length - 1, length - 1] = 1
                     tgt_pos[i, 0, 0] = paddle.to_tensor(
@@ -434,7 +458,7 @@ def dy_input_preprocess(inputs):
     inputs["tgt_generation_mask"] = tgt_generation_mask
     if "chatglm" in args.architecture:
         inputs["tgt_pos"] = tgt_pos
-        inputs["position_ids"] = generate_position_ids_for_chatglm(dec_length)
+        inputs["position_ids"] = generate_position_ids_for_chatglm(enc_length)
     if args.is_ptuning:
         prefix_caches = []
         for model_id in inputs['model_id']:
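The new ChatGLM p-tuning branch builds a bidirectional-prefix attention mask: every prompt token may attend to the cached prefix slots and the whole prompt, and only the final prompt position is hidden from the earlier tokens. Below is a minimal standalone sketch of that layout, not part of the commit; the toy sizes and the float32 dtype are assumptions for illustration.

    import paddle

    # Toy sizes, assumed for illustration only.
    length, max_prefix_len = 4, 2

    # Rows are prompt tokens; columns are prefix slots followed by prompt tokens.
    mask = paddle.zeros([length, length + max_prefix_len], dtype="float32")

    # Every prompt token sees the prefix cache and the full prompt...
    mask[:length, :length + max_prefix_len] = 1
    # ...except that tokens before the last one cannot see the final prompt
    # position, mirroring ChatGLM's bidirectional-prefix convention.
    mask[:length - 1, length + max_prefix_len - 1] = 0

    print(mask.numpy())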

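tgt_generation_mask is what the decoder consults once generation starts. In the updated code the prompt span is always opened up, while the prefix span is opened only when the request actually carries a p-tuning prefix cache (a non-empty model_id). A hedged sketch of just that branch, with assumed toy values:

    import paddle

    length, max_prefix_len = 4, 2  # assumed toy sizes
    model_id = ""                  # empty: this request has no prefix cache

    tgt_mask = paddle.zeros([1, 1, 1, length + max_prefix_len], dtype="float32")
    if not model_id:
        # No prefix cache: keep the prefix slots masked, expose only the prompt.
        tgt_mask[0, 0, 0, max_prefix_len:length + max_prefix_len] = 1
    else:
        # Prefix cache present: expose the prefix slots and the prompt alike.
        tgt_mask[0, 0, 0, :length + max_prefix_len] = 1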