[LLM] Support dynamic batching for chatglm (#2251)
* [LLM] Support dynamic batching for chatglm

* fix bug in triton model

* fix bug

* fix bug
jiangjiajun authored Oct 20, 2023
1 parent a5a261b commit 4c21588
Showing 4 changed files with 56 additions and 58 deletions.
91 changes: 45 additions & 46 deletions llm/fastdeploy_llm/engine.py
@@ -372,7 +372,7 @@ def get_alibi_slopes(num_heads):
alibi_decoder + (1 - inputs["tgt_generation_mask"]) *
paddle.finfo(inputs["tgt_generation_mask"].dtype).min)
attention_mask = inputs["attention_mask"]
tgt_generation_mask = inputs["tgt_generation_mask"]
tgt_generation_mask = inputs["tgt_generation_mask"]


def dy_input_preprocess(inputs):
@@ -394,70 +394,70 @@ def dy_input_preprocess(inputs):
        if args.is_ptuning:
            model_id = inputs['model_id'][i]
            if "chatglm" in args.architecture:
-               attention_mask[i, 0, :length, : length+max_prefix_len] = 1
-               attention_mask[i, 0, :length - 1, length+max_prefix_len - 1] = 0
+               attention_mask[i, 0, :length, :length + max_prefix_len] = 1
+               attention_mask[i, 0, :length - 1, length + max_prefix_len -
+                              1] = 0
                tgt_pos[i, 0, 0] = paddle.to_tensor(
                    [length], dtype="int64")

                if not model_id:
-                   tgt_generation_mask[i, 0, 0, max_prefix_len : length + max_prefix_len] = paddle.ones(
-                       shape=[1, length], dtype=args.dtype
-                   )
+                   tgt_generation_mask[i, 0, 0, max_prefix_len:length +
+                                       max_prefix_len] = paddle.ones(
+                                           shape=[1, length],
+                                           dtype=args.dtype)
                else:
-                   tgt_generation_mask[i, 0, 0, : length + max_prefix_len] = paddle.ones(
-                       shape=[1, length + max_prefix_len], dtype=args.dtype
-                   )
+                   tgt_generation_mask[
+                       i, 0, 0, :length + max_prefix_len] = paddle.ones(
+                           shape=[1, length + max_prefix_len],
+                           dtype=args.dtype)
            else:
                if "bloom" in args.architecture:
-                   attention_mask[i, :, :length, :length] = paddle.tril(
+                   attention_mask[i, :, :length, :length] = paddle.tril(
                        paddle.ones(
-                           shape=[length, length], dtype=args.dtype))
+                           shape=[length, length], dtype=args.dtype))
                if not model_id:
                    attention_mask[i, :, :length, :
-                                  max_prefix_len] = paddle.zeros(
-                                      [1, length, max_prefix_len],
-                                      dtype=args.dtype)
-                   tgt_generation_mask[i, 0, 0, max_prefix_len : length + max_prefix_len] = paddle.ones(
-                       shape=[1, length], dtype=args.dtype
-                   )
+                                  max_prefix_len] = paddle.zeros(
+                                      [1, length, max_prefix_len],
+                                      dtype=args.dtype)
+                   tgt_generation_mask[i, 0, 0, max_prefix_len:length +
+                                       max_prefix_len] = paddle.ones(
+                                           shape=[1, length],
+                                           dtype=args.dtype)
                else:
                    attention_mask[i, :, :length, :
-                                  max_prefix_len] = paddle.ones(
-                                      [1, length, max_prefix_len],
-                                      dtype=args.dtype)
-                   tgt_generation_mask[i, 0, 0, :max_prefix_len +
-                                       length] = paddle.ones(
-                                           shape=[1, max_prefix_len + length],
-                                           dtype=args.dtype)
-                   attention_mask[i, :, :length, max_prefix_len:max_prefix_len +
-                                  length] = paddle.tril(
-                                      paddle.ones(
-                                          shape=[length, length],
-                                          dtype=args.dtype))
+                                  max_prefix_len] = paddle.ones(
+                                      [1, length, max_prefix_len],
+                                      dtype=args.dtype)
+                   tgt_generation_mask[
+                       i, 0, 0, :max_prefix_len + length] = paddle.ones(
+                           shape=[1, max_prefix_len + length],
+                           dtype=args.dtype)

+                   attention_mask[i, :, :length, max_prefix_len:max_prefix_len
+                                  + length] = paddle.tril(
+                                      paddle.ones(
+                                          shape=[length, length],
+                                          dtype=args.dtype))
                position_ids[i, :max_prefix_len] = 0
                position_ids[i, max_prefix_len:max_prefix_len + inputs[
-                   "input_ids"].shape[1]] = paddle.arange(inputs["input_ids"]
-                                                          .shape[1])
+                   "input_ids"].shape[1]] = paddle.arange(inputs[
+                       "input_ids"].shape[1])
                if "bloom" in args.architecture:
-                   tgt_generation_mask[i, :, 0, :max_prefix_len +
-                                       length] = paddle.ones(
-                                           shape=[1, max_prefix_len + length],
-                                           dtype=args.dtype)
-                   arange_tensor_encoder[i, :, :length + max_prefix_len] = paddle.arange(
-                       length + max_prefix_len).astype(args.dtype)


+                   tgt_generation_mask[
+                       i, :, 0, :max_prefix_len + length] = paddle.ones(
+                           shape=[1, max_prefix_len + length],
+                           dtype=args.dtype)
+                   arange_tensor_encoder[
+                       i, :, :length + max_prefix_len] = paddle.arange(
+                           length + max_prefix_len).astype(args.dtype)
        else:
            if "chatglm" in args.architecture:
-               attention_mask[i, 0, :length, :length] = 1
+               attention_mask[i, 0, :length, :length] = 1
                attention_mask[i, 0, :length - 1, length - 1] = 0
                tgt_pos[i, 0, 0] = paddle.to_tensor(
                    [length], dtype="int64")
                tgt_generation_mask[i, 0, 0, :length] = paddle.ones(
                    shape=[1, length], dtype=args.dtype)
            else:

                position_ids[i, :length] = paddle.arange(length)
                attention_mask[i, 0, :length, :length] = paddle.tril(
                    paddle.ones(
@@ -475,7 +475,6 @@ def dy_input_preprocess(inputs):
inputs["position_ids"] = position_ids
inputs["tgt_generation_mask"] = tgt_generation_mask
if "chatglm" in args.architecture:
inputs["tgt_pos"] = tgt_pos
inputs["position_ids"] = generate_position_ids_for_chatglm(enc_length)
if args.is_ptuning:
prefix_caches = []
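For context on the dy_input_preprocess changes above: the p-tuning branch packs a soft-prompt prefix of max_prefix_len slots in front of each request's length prompt tokens, and the attention and generation masks are sliced accordingly. The sketch below is a minimal, self-contained illustration of that layout for a single ChatGLM-style sequence; the toy sizes and the has_prefix flag are assumptions made for the example, not values taken from FastDeploy.

import paddle

# Toy sizes (assumptions for illustration only)
max_len, max_prefix_len, length = 8, 2, 3
dtype = "float32"

# Attention mask over [prefix | prompt], mirroring the two assignments in the
# diff: prompt rows see the whole prefix + prompt, except that rows before the
# last prompt token do not see the last prompt position.
attention_mask = paddle.zeros([1, 1, max_len, max_len + max_prefix_len], dtype=dtype)
attention_mask[0, 0, :length, :length + max_prefix_len] = 1
attention_mask[0, 0, :length - 1, length + max_prefix_len - 1] = 0

# Generation mask for the next token: skip the prefix slots when the request
# carries no p-tuning model_id, otherwise expose prefix + prompt.
tgt_generation_mask = paddle.zeros([1, 1, 1, max_len + max_prefix_len], dtype=dtype)
has_prefix = False  # assumption: a request without a prefix cache
if has_prefix:
    tgt_generation_mask[0, 0, 0, :length + max_prefix_len] = paddle.ones(
        shape=[1, length + max_prefix_len], dtype=dtype)
else:
    tgt_generation_mask[0, 0, 0, max_prefix_len:length + max_prefix_len] = paddle.ones(
        shape=[1, length], dtype=dtype)

print(attention_mask[0, 0, :length].numpy())
print(tgt_generation_mask[0, 0, 0].numpy())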
13 changes: 9 additions & 4 deletions llm/fastdeploy_llm/model.py
@@ -298,9 +298,9 @@ def _update_task_results(self, tasks):
        for i in range(len(tasks)):
            tasks[i].decode_status["finished_id"] = int(info["finished_ids"][
                i])
-           # TODO JiangJiajun
            if self.config.is_arch("chatglm"):
-               tasks[i].decode_status["tgt_pos"] = int(info["tgt_pos"][i][0])
+               tasks[i].decode_status["tgt_pos"] = info["tgt_pos"][i].flatten(
+               ).tolist()
            else:
                tasks[i].decode_status["tgt_pos"] = int(info["tgt_pos"][i])

@@ -398,9 +398,14 @@ def _add_dynamic_batching_inputs(self, inputs, tasks, stop_nums):
inputs["model_id"].append(tasks[i].model_id)
length = inputs["num_input_tokens"][i]
if tasks[i].status == TaskStatus.DECODING:
tgt_pos.append(tasks[i].decode_status["tgt_pos"])
if self.config.is_arch("chatglm"):
tgt_pos.append(1)
tgt_pos += [
tasks[i].decode_status["seq_lens_decoder"] -
tasks[i].decode_status["step_idx"] + 1,
tasks[i].decode_status["step_idx"]
]
else:
tgt_pos.append(tasks[i].decode_status["tgt_pos"])
sequence_lengths_encoder[i] = 0
sequence_lengths_decoder[i] = tasks[i].decode_status[
"seq_lens_decoder"]
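The _add_dynamic_batching_inputs change above makes a decoding ChatGLM task contribute two tgt_pos entries, derived from seq_lens_decoder and step_idx, instead of the single scalar used by other architectures, matching the flattened list now stored by _update_task_results. Below is a hedged, standalone sketch of that packing, with decode_status reduced to a plain dict for illustration.

def pack_tgt_pos(decode_statuses, is_chatglm):
    """Build the flat tgt_pos list for a batch of decoding tasks."""
    tgt_pos = []
    for status in decode_statuses:
        if is_chatglm:
            # Two entries per ChatGLM task, as in the diff above. Reading them
            # as ChatGLM's 2-D (position, block position) encoding is an
            # assumption, not something the commit states.
            tgt_pos += [
                status["seq_lens_decoder"] - status["step_idx"] + 1,
                status["step_idx"],
            ]
        else:
            # Other architectures keep a single scalar position per task.
            tgt_pos.append(status["tgt_pos"])
    return tgt_pos


# Toy usage: two ChatGLM tasks that have emitted 2 and 5 tokens so far.
statuses = [
    {"seq_lens_decoder": 12, "step_idx": 2},
    {"seq_lens_decoder": 31, "step_idx": 5},
]
print(pack_tgt_pos(statuses, is_chatglm=True))        # [11, 2, 27, 5]
print(pack_tgt_pos([{"tgt_pos": 7}], is_chatglm=False))  # [7]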
6 changes: 0 additions & 6 deletions llm/fastdeploy_llm/serving/serving_model.py
@@ -23,12 +23,6 @@ class ServingModel:
    def __init__(self, config):
        self.config = config

-       if self.config.is_arch("chatglm") or self.config.is_arch("bloom"):
-           logger.warning(
-               "Dynamic batching will be disabled for model ChatGLM/BLOOM now!"
-           )
-           self.config.disable_dynamic_batching = 1
-
        logger.info("=============== Debug Information ===============")
        for k, v in self.config.__dict__.items():
            logger.info("{:<20}:{:<6}{}".format(k, "", v))
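With the override above removed, ChatGLM and BLOOM go through whatever batching mode the caller configured instead of being forced to disable_dynamic_batching = 1. A minimal sketch with a stand-in config object (an assumption, not the real fastdeploy_llm config class) to show the behavior change:

class DemoConfig:
    """Stand-in for the serving config; only the fields used here."""
    architecture = "chatglm-6b"
    disable_dynamic_batching = 0  # the caller's choice is now respected

    def is_arch(self, name):
        # Assumed substring check, mirroring how is_arch is used in the diff.
        return name in self.architecture.lower()


cfg = DemoConfig()
# Before this commit the ServingModel constructor would have flipped the flag
# to 1 for ChatGLM/BLOOM; now only an explicit caller setting changes it.
assert cfg.is_arch("chatglm")
assert cfg.disable_dynamic_batching == 0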
4 changes: 2 additions & 2 deletions llm/fastdeploy_llm/serving/triton_model.py
@@ -122,11 +122,12 @@ def execute(self, requests):
                continue

            # 2. validate the deserializing process
+           task = Task()
            try:
                task.from_dict(data)
            except Exception as e:
                error_res = pb_utils.InferenceResponse(error=pb_utils.TritonError(
-                   "There's error while deserializing data from reqeust, error={}".
+                   "There's error while deserializing data from request, error={}".
                    format(e)))
                res_sender = request.get_response_sender()
                res_sender.send(
@@ -135,7 +136,6 @@ def execute(self, requests):
                continue

            # 3. check if exists task id conflict
-           task = Task()
            if task.task_id is None:
                task.task_id = str(uuid.uuid4())
            if task.task_id in self.response_handler:
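The two triton_model.py edits fix an ordering bug: task.from_dict(data) ran inside the try block but referenced a task that had not yet been constructed for this request (the constructor sat further down, in step 3). Moving the construction ahead of the deserialization and dropping the later duplicate makes the error path meaningful. Below is a standalone sketch of the corrected pattern, using a stand-in Task class rather than the fastdeploy_llm one:

import uuid


class Task:
    """Stand-in task with only the members this sketch needs."""

    def __init__(self):
        self.task_id = None

    def from_dict(self, data):
        if not isinstance(data, dict):
            raise ValueError("payload must be a dict")
        self.task_id = data.get("task_id")


def deserialize_request(data):
    task = Task()  # construct first, so from_dict has an object to fill
    try:
        task.from_dict(data)  # 2. validate the deserializing process
    except Exception as e:
        return None, "There's error while deserializing data from request, error={}".format(e)
    if task.task_id is None:  # 3. resolve a missing task id
        task.task_id = str(uuid.uuid4())
    return task, None


task, err = deserialize_request({"prompt": "hello"})
assert err is None and task.task_id is not None
_, err = deserialize_request("not-a-dict")
assert err is not None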
