app_internlm.py
from dataclasses import asdict

import streamlit as st
import torch
from modelscope.hub.snapshot_download import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import logging

from tools.transformers.interface import GenerationConfig, generate_interactive

# Download the fine-tuned weights from ModelScope (cached locally after the first run).
model_id = 'baiyu96/career_coach'
model_dir = snapshot_download(model_id, revision='v0.3.0')

logger = logging.get_logger(__name__)
def on_btn_click():
    # Wipe the chat history stored in the Streamlit session state.
    del st.session_state.messages
@st.cache_resource
def load_model():
    # Load the model once per process; Streamlit reuses the cached resource across reruns.
    model = (
        AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)
        .to(torch.bfloat16)
        .cuda()
    )
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
    return model, tokenizer
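
# Note (assumption, not part of the original file): the .cuda() call above means a
# CUDA-capable GPU is required; running on CPU would need a float32 load without
# .cuda(), at the cost of much slower generation.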
def prepare_generation_config():
    with st.sidebar:
        # Bare string literals are rendered by Streamlit "magic" as markdown links.
        "[InternLM](https://github.com/InternLM/InternLM.git)"
        "[career_coach](https://github.com/BaiYu96/career_coach)"
        "[career_coach_model](https://www.modelscope.cn/models/baiyu96/career_coach/summary)"
        max_length = st.slider("Max Length", min_value=32, max_value=2048, value=2048)
        top_p = st.slider("Top P", 0.0, 1.0, 0.8, step=0.01)
        temperature = st.slider("Temperature", 0.0, 1.0, 0.7, step=0.01)
        st.button("Clear Chat History", on_click=on_btn_click)

    generation_config = GenerationConfig(max_length=max_length, top_p=top_p, temperature=temperature)
    return generation_config
# InternLM chat prompt templates: past user turns, past bot turns, and the current query.
user_prompt = "<|User|>:{user}\n"
robot_prompt = "<|Bot|>:{robot}<eoa>\n"
cur_query_prompt = "<|User|>:{user}<eoh>\n<|Bot|>:"
def combine_history(prompt):
    # Rebuild the full InternLM-style prompt from the stored chat history,
    # then append the current user query.
    messages = st.session_state.messages
    total_prompt = ""
    for message in messages:
        cur_content = message["content"]
        if message["role"] == "user":
            cur_prompt = user_prompt.replace("{user}", cur_content)
        elif message["role"] == "robot":
            cur_prompt = robot_prompt.replace("{robot}", cur_content)
        else:
            raise RuntimeError(f"Unknown role in chat history: {message['role']}")
        total_prompt += cur_prompt
    total_prompt = total_prompt + cur_query_prompt.replace("{user}", prompt)
    return total_prompt
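
# Illustrative example (not in the original file): with one prior exchange in the
# history, combine_history("How do I prepare for an interview?") returns roughly
#   "<|User|>:<earlier question>\n"
#   "<|Bot|>:<earlier answer><eoa>\n"
#   "<|User|>:How do I prepare for an interview?<eoh>\n<|Bot|>:"
# Only the current query carries <eoh>, and the model continues after the
# trailing "<|Bot|>:".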
def main():
    # torch.cuda.empty_cache()
    print("load model begin.")
    model, tokenizer = load_model()
    print("load model end.")

    user_avatar = "./imgs/user.png"
    robot_avatar = "./imgs/cc-2.png"

    st.title("职场教练")  # "Career Coach"

    generation_config = prepare_generation_config()

    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Display chat messages from history on app rerun
    for message in st.session_state.messages:
        with st.chat_message(message["role"], avatar=message.get("avatar")):
            st.markdown(message["content"])

    # Accept user input
    if prompt := st.chat_input("What is up?"):
        # Display user message in chat message container
        with st.chat_message("user", avatar=user_avatar):
            st.markdown(prompt)
        real_prompt = combine_history(prompt)
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt, "avatar": user_avatar})

        with st.chat_message("robot", avatar=robot_avatar):
            message_placeholder = st.empty()
            # Stream the response and update the placeholder in place as tokens arrive.
            for cur_response in generate_interactive(
                model=model,
                tokenizer=tokenizer,
                prompt=real_prompt,
                additional_eos_token_id=103028,
                **asdict(generation_config),
            ):
                # Display robot response in chat message container
                message_placeholder.markdown(cur_response + "▌")
            message_placeholder.markdown(cur_response)
        # Add robot response to chat history
        st.session_state.messages.append({"role": "robot", "content": cur_response, "avatar": robot_avatar})
        torch.cuda.empty_cache()
if __name__ == "__main__":
    main()
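
# Usage note (assumption, not part of the original script): launch the app with
#   streamlit run app_internlm.py
# The first run downloads the weights from ModelScope; generation requires a
# CUDA GPU because the model is cast to bfloat16 and moved to .cuda() above.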