forked from RLHF-V/RLAIF-V
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path chat.py
226 lines (190 loc) · 8.78 KB
/
chat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
import json
import torch
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, \
DEFAULT_IMAGE_PATCH_TOKEN
from llava.conversation import conv_templates
from builder.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, process_images
from PIL import Image
import base64
import io
import os
from omnilmm.train.train_utils import omni_preprocess
def init_omni_lmm(model_path):
    """Load the OmniLMM-based RLAIF-V 12B model together with its helpers.

    Args:
        model_path: Local checkpoint directory or hub id. A leading ``~`` is
            expanded before loading.

    Returns:
        Tuple ``(model, image_processor, image_token_len, tokenizer)``, where
        ``image_token_len`` is read from ``model.model.config.num_query``.
    """
    # Process-wide torch flag: allow TF32 for CUDA matmuls.
    torch.backends.cuda.matmul.allow_tf32 = True
    disable_torch_init()
    model_name = os.path.expanduser(model_path)
    print(f'Load model and tokenizer from {model_name}')
    # Load from the *expanded* path as well. Previously only the name was
    # expanded, so a '~/...' path was printed correctly but failed to load.
    tokenizer, model, image_processor, _ = load_pretrained_model(model_name, None, model_name)
    image_token_len = model.model.config.num_query
    return model, image_processor, image_token_len, tokenizer
def expand_question_into_multimodal(question_text, image_token_len, im_st_token, im_ed_token, im_patch_token):
    """Insert the image-token span into the first message of a conversation.

    The span is ``im_st_token + im_patch_token * image_token_len + im_ed_token``.
    If the first message's content contains ``'<image>'``, every occurrence is
    replaced by the span. Otherwise the span plus a newline is prepended.

    The first message dict is modified in place, and the same list object is
    returned.
    """
    first_turn = question_text[0]
    content = first_turn['content']
    image_span = im_st_token + im_patch_token * image_token_len + im_ed_token
    if '<image>' not in content:
        first_turn['content'] = image_span + '\n' + content
    else:
        first_turn['content'] = content.replace('<image>', image_span)
    return question_text
def wrap_question_for_omni_lmm(question, image_token_len, tokenizer):
    """Turn a question into OmniLMM generation inputs.

    Args:
        question: Either a plain string, which is wrapped as a single user
            turn, or a list of ``{"role": ..., "content": ...}`` messages.
            The first message's content is rewritten in place by
            ``expand_question_into_multimodal`` to carry the image tokens.
        image_token_len: Number of image patch tokens to insert.
        tokenizer: Tokenizer forwarded to ``omni_preprocess``.

    Returns:
        dict with ``input_ids`` and ``labels`` taken from element ``[0]`` of
        the corresponding ``omni_preprocess`` outputs.
    """
    messages = question
    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]
    messages = expand_question_into_multimodal(
        messages,
        image_token_len,
        DEFAULT_IM_START_TOKEN,
        DEFAULT_IM_END_TOKEN,
        DEFAULT_IMAGE_PATCH_TOKEN,
    )
    processed = omni_preprocess(sources=[messages], tokenizer=tokenizer, generation=True)
    return {
        "input_ids": processed["input_ids"][0],
        "labels": processed["labels"][0],
    }
class RLAIFV12B:
    """Chat wrapper around the OmniLMM-based RLAIF-V 12B checkpoint."""

    def __init__(self, model_path) -> None:
        model, img_processor, image_token_len, tokenizer = init_omni_lmm(model_path)
        self.model = model
        self.image_token_len = image_token_len
        self.image_transform = img_processor
        self.tokenizer = tokenizer
        self.model.eval()

    def decode(self, image, input_ids, param=None):
        """Generate a response for one preprocessed image/prompt pair.

        Args:
            image: Transformed image tensor for a single sample. A batch
                dimension is added here.
            input_ids: Prompt token ids (presumably 1-D; a batch dimension is
                added here).
            param: Optional dict of generation kwargs. ``None`` selects the
                built-in sampling settings below. When given, its keys replace
                those settings and may also override ``max_new_tokens``,
                ``output_scores`` and ``return_dict_in_generate``. Previously
                such keys raised "got multiple values for keyword argument".

        Returns:
            The decoded text with special tokens skipped and surrounding
            whitespace stripped.
        """
        gen_kwargs = {
            'max_new_tokens': 1024,
            'output_scores': True,
            'return_dict_in_generate': True,
        }
        if param is None:
            gen_kwargs.update(
                temperature=0.6,
                num_beams=3,
                do_sample=True,
                repetition_penalty=1.1,
                top_k=30,
                top_p=0.9,
            )
        else:
            gen_kwargs.update(param)
        with torch.inference_mode():
            output = self.model.generate_vllm(
                input_ids=input_ids.unsqueeze(0).cuda(),
                images=image.unsqueeze(0).half().cuda(),
                **gen_kwargs,
            )
        response = self.tokenizer.decode(output.sequences[0], skip_special_tokens=True)
        return response.strip()

    def chat(self, input, param=None):
        """Answer ``input['question']`` about ``input['image']``.

        Args:
            input: dict with ``'question'`` and ``'image'``. The image is either
                a file path (str) or an already-loaded image object (presumably
                a PIL image).
            param: Optional generation kwargs forwarded to :meth:`decode`.

        Returns:
            The model's response, or the string ``"Image decode error"`` if the
            image cannot be decoded or converted.

        Raises:
            OSError: if a path is given and the file cannot be opened/read.
        """
        source = input['image']
        raw_bytes = None
        if isinstance(source, str):
            # Read outside the try: an unreadable path raises, as before. Only
            # decoding failures are reported as "Image decode error".
            with open(source, 'rb') as f:
                raw_bytes = f.read()
        messages = [{"role": "user", "content": input['question']}]
        try:
            if raw_bytes is not None:
                image = Image.open(io.BytesIO(raw_bytes)).convert('RGB')
            else:
                image = source
                # Keep the same colour mode as the path branch.
                if isinstance(image, Image.Image) and image.mode != 'RGB':
                    image = image.convert('RGB')
        except Exception:
            # Deliberate best-effort: report decode failures to the caller
            # instead of raising.
            return "Image decode error"
        input_ids = wrap_question_for_omni_lmm(
            messages, self.image_token_len, self.tokenizer)['input_ids']
        input_ids = torch.as_tensor(input_ids)
        image = self.image_transform(image)
        return self.decode(image, input_ids, param=param)
def img2base64(file_name):
    """Return the raw contents of *file_name*, base64-encoded, as ``bytes``."""
    with open(file_name, 'rb') as handle:
        raw = handle.read()
    return base64.b64encode(raw)
class RLAIFV7B:
    """Chat wrapper for the LLaVA-1.5-based RLAIF-V 7B checkpoint."""

    def __init__(self, model_path) -> None:
        disable_torch_init()
        # Fixed name handed to the loader. Presumably it selects the LLaVA
        # loading path inside load_pretrained_model; TODO confirm.
        model_name = 'llava-v1.5-7b'
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            model_path, model_base=None, model_name=model_name, device_map={"": 'cuda'})
        self.tokenizer = tokenizer
        self.model = model
        self.image_processor = image_processor
        self.context_len = context_len

    def chat(self, input):
        """Answer ``input['question']`` about ``input['image']`` using beam search.

        Args:
            input: dict with ``'question'`` (str) and ``'image'``. The image is
                either a file path (str) or an already-loaded image (presumably
                a PIL image).

        Returns:
            The decoded response with surrounding whitespace stripped.
        """
        question = input['question']
        if self.model.config.mm_use_im_start_end:
            image_prefix = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        else:
            image_prefix = DEFAULT_IMAGE_TOKEN
        msgs = image_prefix + '\n' + question

        image = input['image']
        if isinstance(image, str):
            # The context manager releases the file handle deterministically.
            with Image.open(image) as opened:
                image = opened.convert('RGB')
        elif isinstance(image, Image.Image) and image.mode != 'RGB':
            # Match the path branch: always pass RGB images to process_images.
            image = image.convert('RGB')

        conv = conv_templates["llava_v1"].copy()
        conv.append_message(conv.roles[0], msgs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        input_ids = tokenizer_image_token(
            prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        image_tensor = process_images([image], self.image_processor, self.model.config)[0]
        with torch.inference_mode():
            # do_sample=False with num_beams=3: deterministic beam search.
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor.unsqueeze(0).half().cuda(),
                image_sizes=[image.size],
                do_sample=False,
                temperature=0,
                num_beams=3,
                max_new_tokens=1024,
                use_cache=True)
        return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
class RLAIFV7BLoRA:
    """Chat wrapper for an RLAIF-V LoRA checkpoint on top of a LLaVA-1.5 base."""

    def __init__(self, model_path, model_base) -> None:
        disable_torch_init()
        tokenizer, model, image_processor, context_len = load_pretrained_model(
            model_path=model_path, model_base=model_base, model_name='llava_lora_model', device_map={"": 'cuda'})
        self.tokenizer = tokenizer
        self.model = model
        self.image_processor = image_processor
        self.context_len = context_len

    def chat(self, input):
        """Answer ``input['question']`` about ``input['image']`` using beam search.

        Args:
            input: dict with ``'question'`` (str) and ``'image'``. The image is
                either a file path (str) or, like ``RLAIFV7B.chat``, an
                already-loaded image (presumably a PIL image). Previously only
                paths were accepted here.

        Returns:
            The decoded response with surrounding whitespace stripped.
        """
        question = input['question']
        if self.model.config.mm_use_im_start_end:
            image_prefix = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        else:
            image_prefix = DEFAULT_IMAGE_TOKEN
        msgs = image_prefix + '\n' + question

        image = input['image']
        if isinstance(image, str):
            # The context manager releases the file handle deterministically.
            with Image.open(image) as opened:
                image = opened.convert('RGB')
        elif isinstance(image, Image.Image) and image.mode != 'RGB':
            # Match the path branch: always pass RGB images to process_images.
            image = image.convert('RGB')

        conv = conv_templates["llava_v1"].copy()
        conv.append_message(conv.roles[0], msgs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
        input_ids = tokenizer_image_token(
            prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        image_tensor = process_images([image], self.image_processor, self.model.config)[0]
        with torch.inference_mode():
            # do_sample=False with num_beams=3: deterministic beam search.
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor.unsqueeze(0).half().cuda(),
                image_sizes=[image.size],
                do_sample=False,
                temperature=0,
                num_beams=3,
                max_new_tokens=1024,
                use_cache=True)
        return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
class RLAIFVChat:
    """Pick the RLAIF-V wrapper that matches *model_path* and forward chats to it."""

    def __init__(self, model_path) -> None:
        """Choose the backend from substrings of *model_path*.

        Raises:
            ValueError: if the path contains none of ``'12B'``,
                ``'lora_checkpoint'`` or ``'7B'``. Previously this silently
                left ``self.model`` unset and failed later in :meth:`chat`.
        """
        if '12B' in model_path:
            self.model = RLAIFV12B(model_path)
        # The LoRA test must come before the plain '7B' test: LoRA paths such
        # as 'RLAIF-V/RLAIF-V-7B/lora_checkpoints' also contain '7B' and were
        # wrongly routed to RLAIFV7B.
        elif 'lora_checkpoint' in model_path:
            self.model = RLAIFV7BLoRA(model_path, model_base='liuhaotian/llava-v1.5-7b')
        elif '7B' in model_path:
            self.model = RLAIFV7B(model_path)
        else:
            raise ValueError(
                f"Cannot infer RLAIF-V model type from path {model_path!r}; "
                "expected it to contain '12B', 'lora_checkpoint' or '7B'.")

    def chat(self, input, param=None):
        """Forward *input* to the selected backend.

        *param* (generation kwargs) is forwarded only when given. Among the
        backends, only the 12B wrapper's ``chat`` accepts it.
        """
        if param is None:
            return self.model.chat(input)
        return self.model.chat(input, param=param)
if __name__ == '__main__':
    # Demo: ask one question about a local example image.
    # Use 'HaoyeZhang/RLAIF-V-12B' as the path to try the 12B model instead.
    demo_model = RLAIFVChat('RLAIF-V/RLAIF-V-7B/lora_checkpoints')
    demo_inputs = {
        "image": "./examples/test.jpeg",
        "question": "Why did the car in the picture stop?",
    }
    print(demo_model.chat(demo_inputs))