[New Features] support dynamic src_length #7740
@@ -17,7 +17,7 @@
 import os
 import socket
 from contextlib import closing
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
 from time import sleep

 import requests
@@ -68,6 +68,7 @@ def __init__(self, args: ServerArgument, predictor: BasePredictor):
     self.args.flask_port + port_interval * predictor.tensor_parallel_rank,
     self.args.flask_port + port_interval * (predictor.tensor_parallel_rank + 1),
 )
+self.total_max_length = predictor.config.src_length + predictor.config.max_length

 if self.predictor.tensor_parallel_rank == 0:
     # fetch port info
@@ -123,16 +124,44 @@ def streaming(data):

 # build chat template
 if self.predictor.tokenizer.chat_template is not None:
-    history = json.loads(history)
+    if not history:
+        history = []
+    # also support history data
+    elif isinstance(history, str):
+        history = json.loads(history)

     assert len(history) % 2 == 0
Review comment: If len(history) == 0, none of the code below needs to run; I'm worried query ends up as [[]].
     chat_query = []
     for idx in range(0, len(history), 2):
-        chat_query.append(["", ""])
-        chat_query[-1][0], chat_query[-1][1] = history[idx]["utterance"], history[idx + 1]["utterance"]
-    query = [chat_query]
+        if isinstance(history[idx], str):
+            chat_query.append([history[idx], history[idx + 1]])
+        elif isinstance(history[idx], dict):
+            chat_query.append([history[idx]["utterance"], history[idx + 1]["utterance"]])
+        else:
+            raise ValueError(
+                "history data should be list[str] or list[dict], eg: ['sentence-1', 'sentence-2', ...], or "
+                "[{'utterance': 'sentence-1'}, {'utterance': 'sentence-2'}, ...]"
+            )
+
+    # the input of predictor should be batched.
+    # batched query: [ [[user, bot], [user, bot], ..., [user]] ]
+    query = [chat_query + [[query]]]

 generation_args = data
 self.predictor.config.max_length = generation_args["max_length"]
+if "src_length" in generation_args:
Review comment: Shouldn't src_length plus max_length be kept under the maximum here? If the sum exceeds it, shouldn't it be adjusted?
Reply: This was discussed before. The conclusion was to use src_length and max_length separately to cap the input and output lengths, and not to enforce a combined limit. The reason for not enforcing it is that the model may support length extrapolation, and a hard cap would restrict that capability.
Reply: Then enumerate the models that support extrapolation and leave them unrestricted, and restrict the ones that cannot extrapolate.
+    self.predictor.config.src_length = generation_args["src_length"]

+if self.predictor.config.src_length + self.predictor.config.max_length > self.total_max_length:
+    output = {
+        "error_code": 1,
+        "error_msg": f"The sum of src_length<{self.predictor.config.src_length}> and "
+        f"max_length<{self.predictor.config.max_length}> should be smaller than or equal to "
+        f"the maximum position embedding size<{self.total_max_length}>",
+    }
+    yield json.dumps(output, ensure_ascii=False) + "\n"
+    return

 self.predictor.config.top_p = generation_args["top_p"]
 self.predictor.config.temperature = generation_args["temperature"]
 self.predictor.config.top_k = generation_args["top_k"]
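For illustration, here is a minimal, self-contained sketch of the history handling and query batching added in this hunk; the helper name build_batched_query is hypothetical and only mirrors the logic above, it is not code from the PR.

def build_batched_query(query, history):
    # Accept either list[str] or list[dict] history, as in the hunk above.
    chat_query = []
    for idx in range(0, len(history), 2):
        if isinstance(history[idx], str):
            chat_query.append([history[idx], history[idx + 1]])
        elif isinstance(history[idx], dict):
            chat_query.append([history[idx]["utterance"], history[idx + 1]["utterance"]])
        else:
            raise ValueError("history data should be list[str] or list[dict]")
    # batched query: [ [[user, bot], ..., [user]] ]
    return [chat_query + [[query]]]

# Both history formats yield the same batch structure:
assert build_batched_query("q2", ["q1", "a1"]) == [[["q1", "a1"], ["q2"]]]
assert build_batched_query("q2", [{"utterance": "q1"}, {"utterance": "a1"}]) == [[["q1", "a1"], ["q2"]]]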
@@ -160,13 +189,13 @@ def streaming(data):
     # refer to: https://github.com/pallets/flask/blob/main/src/flask/app.py#L605
     app.run(host="0.0.0.0", port=self.port, threaded=False)

-def start_ui_service(self, args):
+def start_ui_service(self, args, predictor_args):
     # do not support start ui service in one command
     from multiprocessing import Process

     from gradio_ui import main

-    p = Process(target=main, args=(args,))
+    p = Process(target=main, args=(args, predictor_args))
     p.daemon = True
     p.start()
@@ -194,6 +223,6 @@ def start_ui_service(self, args):
 server = PredictorServer(server_args, predictor)

 if server.predictor.tensor_parallel_rank == 0:
-    server.start_ui_service(server_args)
+    server.start_ui_service(server_args, asdict(predictor.config))

 server.start_flask_server()
@@ -30,7 +30,7 @@ def setup_args():
     return args


-def launch(args):
+def launch(args, default_params: dict = {}):
     """Launch characters dialogue demo."""

     def rollback(state):
@@ -42,7 +42,7 @@ def rollback(state):
         shown_context = get_shown_context(context)
         return utterance, shown_context, context, state

-    def regen(state, top_k, top_p, temperature, repetition_penalty, max_length):
+    def regen(state, top_k, top_p, temperature, repetition_penalty, max_length, src_length):
         """Regenerate response."""
         context = state.setdefault("context", [])
         if len(context) < 2:
@@ -74,13 +74,13 @@ def begin(utterance, state):
         shown_context = get_shown_context(context)
         return utterance, shown_context, context, state

-    def infer(utterance, state, top_k, top_p, temperature, repetition_penalty, max_length):
+    def infer(utterance, state, top_k, top_p, temperature, repetition_penalty, max_length, src_length):
         """Model inference."""
         utterance = utterance.strip().replace("<br>", "\n")
         context = state.setdefault("context", [])

         if not utterance:
-            gr.Warning("invalid inputs111")
+            gr.Warning("invalid inputs")
             # gr.Warning("请输入有效问题")
             shown_context = get_shown_context(context)
             return None, shown_context, context, state
@@ -93,11 +93,17 @@ def infer(utterance, state, top_k, top_p, temperature, repetition_penalty, max_l
             "temperature": temperature,
             "repetition_penalty": repetition_penalty,
             "max_length": max_length,
+            "src_length": src_length,
             "min_length": 1,
         }
         res = requests.post(f"http://0.0.0.0:{args.flask_port}/api/chat", json=data, stream=True)
         for line in res.iter_lines():
             result = json.loads(line)
+            if result["error_code"] != 0:
+                gr.Warning(result["error_msg"])
+                shown_context = get_shown_context(context)
+                return None, shown_context, context, state

             bot_response = result["result"]["response"]

             # replace \n with br: https://github.com/gradio-app/gradio/issues/4344
@@ -156,29 +162,53 @@ def get_shown_context(context):
     with gr.Row():
         with gr.Column(scale=1):
             top_k = gr.Slider(
-                minimum=1, maximum=100, value=50, step=1, label="Top-k", info="该参数越大,模型生成结果更加随机,反之生成结果更加确定。"
+                minimum=0,
+                maximum=default_params.get("top_k", 20),
+                value=0,
+                step=1,
+                label="Top-k",
+                info="该参数越大,模型生成结果更加随机,反之生成结果更加确定。",
             )
             top_p = gr.Slider(
-                minimum=0, maximum=1, value=0.7, step=0.05, label="Top-p", info="该参数越大,模型生成结果更加随机,反之生成结果更加确定。"
+                minimum=0,
+                maximum=1,
+                value=default_params.get("top_p", 0.7),
+                step=0.05,
+                label="Top-p",
+                info="该参数越大,模型生成结果更加随机,反之生成结果更加确定。",
             )
             temperature = gr.Slider(
                 minimum=0.05,
                 maximum=1.5,
-                value=0.95,
+                value=default_params.get("temperature", 0.95),
                 step=0.05,
                 label="Temperature",
                 info="该参数越小,模型生成结果更加随机,反之生成结果更加确定。",
             )
             repetition_penalty = gr.Slider(
                 minimum=0.1,
                 maximum=10,
-                value=1.0,
+                value=default_params.get("repetition_penalty", 1.2),
                 step=0.05,
                 label="Repetition Penalty",
                 info="该参数越大,生成结果重复的概率越低。设置 1 则不开启。",
             )
+            default_src_length = default_params["src_length"]
+            src_length = gr.Slider(
Review comment: Move this up a bit; put the src_length slider before the max_length slider.
+                minimum=1,
+                maximum=default_src_length,
+                value=default_src_length,
+                step=1,
+                label="Max Src Length",
+                info="最大输入长度。",
+            )
             max_length = gr.Slider(
-                minimum=1, maximum=1024, value=50, step=1, label="Max Length", info="生成结果的最大长度。"
+                minimum=1,
+                maximum=default_params["max_length"],
+                value=50,
+                step=1,
+                label="Max Length",
+                info="生成结果的最大长度。",
             )
         with gr.Column(scale=4):
             state = gr.State({})
@@ -200,7 +230,7 @@ def get_shown_context(context):
             api_name="chat",
         ).then(
             infer,
-            inputs=[utt_text, state, top_k, top_p, temperature, repetition_penalty, max_length],
+            inputs=[utt_text, state, top_k, top_p, temperature, repetition_penalty, max_length, src_length],
             outputs=[utt_text, context_chatbot, raw_context_json, state],
         )

@@ -219,13 +249,13 @@ def get_shown_context(context):
         )
         regen_btn.click(
             regen,
-            inputs=[state, top_k, top_p, temperature, repetition_penalty, max_length],
+            inputs=[state, top_k, top_p, temperature, repetition_penalty, max_length, src_length],
             outputs=[utt_text, context_chatbot, raw_context_json, state],
             queue=False,
             api_name="chat",
         ).then(
             infer,
-            inputs=[utt_text, state, top_k, top_p, temperature, repetition_penalty, max_length],
+            inputs=[utt_text, state, top_k, top_p, temperature, repetition_penalty, max_length, src_length],
             outputs=[utt_text, context_chatbot, raw_context_json, state],
         )

@@ -237,15 +267,15 @@ def get_shown_context(context):
             api_name="chat",
         ).then(
             infer,
-            inputs=[utt_text, state, top_k, top_p, temperature, repetition_penalty, max_length],
+            inputs=[utt_text, state, top_k, top_p, temperature, repetition_penalty, max_length, src_length],
             outputs=[utt_text, context_chatbot, raw_context_json, state],
         )

     block.queue().launch(server_name="0.0.0.0", server_port=args.port, debug=True)


-def main(args):
-    launch(args)
+def main(args, default_params: dict = {}):
+    launch(args, default_params)


 if __name__ == "__main__":
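To illustrate how the new default_params plumbing fits together (asdict(predictor.config) in the flask server, then main(args, default_params) seeding the Gradio sliders), a minimal sketch follows; the dataclass below is a stand-in for illustration, not the real PredictorArgument.

from dataclasses import asdict, dataclass

@dataclass
class StubPredictorConfig:
    # Stand-in with the fields the UI reads; values are examples only.
    top_k: int = 20
    top_p: float = 0.7
    temperature: float = 0.95
    repetition_penalty: float = 1.2
    src_length: int = 1024
    max_length: int = 1024

default_params = asdict(StubPredictorConfig())
# The sliders then take their defaults and upper bounds from this dict, e.g.:
#   maximum=default_params.get("top_k", 20), value=default_params.get("top_p", 0.7),
#   maximum=default_params["src_length"], maximum=default_params["max_length"]
print(default_params["src_length"])  # 1024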
@@ -29,7 +29,10 @@
 from utils import (
     dybatch_preprocess,
     get_alibi_slopes,
+    get_default_max_decoding_length,
+    get_default_max_encoding_length,
     get_infer_model_path,
+    get_model_max_position_embeddings,
     get_prefix_tuning_params,
     init_chat_template,
     load_real_time_tokens,
@@ -56,8 +59,8 @@
 class PredictorArgument:
     model_name_or_path: str = field(default=None, metadata={"help": "The directory of model."})
     model_prefix: str = field(default="model", metadata={"help": "the prefix name of static model"})
-    src_length: int = field(default=1024, metadata={"help": "The max length of source text."})
-    max_length: int = field(default=2048, metadata={"help": "the max length for decoding."})
+    src_length: int = field(default=None, metadata={"help": "The max length of source text."})
+    max_length: int = field(default=None, metadata={"help": "the max length for decoding."})
     top_k: int = field(default=0, metadata={"help": "top_k parameter for generation"})
     top_p: float = field(default=0.7, metadata={"help": "top_p parameter for generation"})
     temperature: float = field(default=0.95, metadata={"help": "top_p parameter for generation"})
@@ -885,6 +888,22 @@ def create_predictor(
         predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
     else:
         raise ValueError("the `mode` should be one of [dynamic, static]")

+    if predictor.config.src_length is None:
+        predictor.config.src_length = get_default_max_encoding_length(predictor.model_config)
+
+    if predictor.config.max_length is None:
+        predictor.config.max_length = get_default_max_decoding_length(predictor.model_config)
Review comment: Shouldn't there be a check here for the case where the user specifies both parameters and src_length + max_length exceeds the maximum? The defaults can never exceed it, but once the user sets them manually, the sum might.

+    max_position_embeddings = get_model_max_position_embeddings(predictor.model_config)
+    if max_position_embeddings is not None:
+        if predictor.config.src_length + predictor.config.max_length > max_position_embeddings:
+            raise ValueError(
+                f"The sum of src_length<{predictor.config.src_length}> and "
+                f"max_length<{predictor.config.max_length}> should be smaller than or equal to "
+                f"the maximum position embedding size<{max_position_embeddings}>"
+            )

     return predictor

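The exact defaults returned by get_default_max_encoding_length and get_default_max_decoding_length live in utils and are not shown in this diff; as a rough sketch of the flow only, assuming purely for illustration that the defaults split the position budget 3:1, the logic above behaves like this:

def resolve_lengths(src_length, max_length, max_position_embeddings=4096):
    # Illustrative stand-in for the create_predictor() logic above; the 3:1
    # split is an assumption for this sketch, not the library's actual policy.
    if src_length is None:
        src_length = max_position_embeddings * 3 // 4
    if max_length is None:
        max_length = max_position_embeddings // 4
    if src_length + max_length > max_position_embeddings:
        raise ValueError(
            f"src_length<{src_length}> + max_length<{max_length}> must not exceed "
            f"max_position_embeddings<{max_position_embeddings}>"
        )
    return src_length, max_length

resolve_lengths(None, None)        # defaults fit by construction
try:
    resolve_lengths(3072, 2048)    # user-specified values can overflow
except ValueError as err:
    print(err)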
@@ -0,0 +1,54 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+import requests
+
+
+def send_request(query, history=None):
+    data = {
+        "context": query,
+        "history": history,
+        "top_k": 0,
+        "top_p": 0.7,  # 0.0 means greedy search
+        "temperature": 0.95,
+        "repetition_penalty": 1.3,
+        "max_length": 100,
+        "src_length": 100,
Review comment: Right now the flask server has no safeguard; if src_length + max_length is made arbitrarily large, the process simply crashes.
Reply: This still doesn't solve the problem; should an error message be returned here?
"min_length": 1, | ||
} | ||
res = requests.post("http://127.0.0.1:8010/api/chat", json=data, stream=True) | ||
text = "" | ||
for line in res.iter_lines(): | ||
result = json.loads(line) | ||
|
||
if result["error_code"] != 0: | ||
text = "error-response" | ||
break | ||
|
||
result = json.loads(line) | ||
bot_response = result["result"]["response"] | ||
|
||
if bot_response["utterance"].endswith("[END]"): | ||
bot_response["utterance"] = bot_response["utterance"][:-5] | ||
text += bot_response["utterance"] | ||
|
||
print("result -> ", text) | ||
Review comment: Move the print outside the function.
Reply: Then every caller would have to print the result themselves. This is an example script that just shows how to use the API, so that doesn't seem necessary.
return text | ||
|
||
|
||
send_request("你好啊") | ||
send_request("再加一等于多少", ["一加一等于多少", "一加一等于二"]) | ||
send_request("再加一等于多少", [{"utterance": "一加一等于多少"}, {"utterance": "一加一等于二"}]) |
Review comment (on self.total_max_length): If the predictor is launched with the defaults (None), these values equal the model maximum. But if the user specifies src_length and max_length, say 128 and 64, then the cap becomes 192. I think this should be written as self.total_max_length = max(max_position (e.g. 4096), predictor.config.src_length + predictor.config.max_length).
Reply: The earlier check already guarantees that predictor.config.src_length + predictor.config.max_length cannot exceed max_position, so using max_position directly here would be enough...
Reply: That is exactly how the logic works now; see the initialization at the bottom of the create_predictor method. So by the time we reach the flask_server, the config is guaranteed to be initialized.
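A small sketch of the two options being weighed in this thread, using the 128/64 figures from the comment and an assumed model maximum of 4096:

max_position = 4096                 # model max_position_embeddings (example value)
src_length, max_length = 128, 64    # user-specified values from the comment

# Reviewer's suggestion: never let the cap drop below the model maximum.
total_max_length = max(max_position, src_length + max_length)    # -> 4096

# Behaviour in this PR: cap at what the user configured, which
# create_predictor() has already validated against max_position.
total_max_length = src_length + max_length                       # -> 192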