matatonic committed Apr 5, 2024 · 1 parent 4827ec0 · commit 3401b8e
Showing 18 changed files with 556 additions and 89 deletions.
@@ -0,0 +1,27 @@
from transformers import AutoTokenizer, AutoModelForCausalLM

from vision_qna import *

class VisionQnA(VisionQnABase):
    model_name: str = "generic"

    def __init__(self, model_id: str, device: str, extra_params = {}, format = None):
        super().__init__(model_id, device, extra_params, format)

        if not format:
            self.format = guess_model_format(model_id)

        self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=self.params.get('trust_remote_code', False))
        self.model = AutoModelForCausalLM.from_pretrained(**self.params).eval()

        print(f"Loaded on device: {self.model.device} with dtype: {self.model.dtype}")

    async def chat_with_images(self, messages: list[Message], max_tokens: int) -> str:
        images, prompt = await prompt_from_messages(messages, self.format)

        # Assumes the loaded model exposes encode_image(); decode the full
        # sequence and let answer_from_response() strip the prompt echo.
        encoded_images = self.model.encode_image(images)
        inputs = self.tokenizer(prompt, encoded_images, return_tensors="pt")
        output = self.model.generate(**inputs, max_new_tokens=max_tokens)
        response = self.tokenizer.decode(output[0], skip_special_tokens=True)

        return answer_from_response(response, self.format)
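
To make the call pattern concrete, a minimal smoke test for this backend might look like the sketch below. The message objects are built with SimpleNamespace stand-ins rather than vision_qna's own types (whose constructors are not shown in this diff); the fields mirror what chat_with_images() actually reads, and the model id is a placeholder.

    import asyncio
    from types import SimpleNamespace as NS  # stand-in for vision_qna's message types

    async def main():
        # Placeholder model id; any causal LM exposing encode_image() is assumed.
        vqa = VisionQnA('org/some-vision-model', device='cuda',
                        extra_params={'trust_remote_code': True})
        messages = [NS(role='user', content=[
            NS(type='image_url', image_url=NS(url='https://example.com/cat.jpg')),
            NS(type='text', text='What is in this image?'),
        ])]
        print(await vqa.chat_with_images(messages, max_tokens=256))

    asyncio.run(main())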
@@ -0,0 +1,55 @@
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Import paths assume minigemini mirrors the LLaVA-style module layout.
from minigemini.model.builder import load_pretrained_model
from minigemini.mm_utils import process_images, tokenizer_image_token
from minigemini.constants import IMAGE_TOKEN_INDEX

from vision_qna import *

class VisionQnA(VisionQnABase):
    model_name: str = "minigemini"
    format: str = "llama2"

    def __init__(self, model_id: str, device: str, extra_params = {}, format = None):
        super().__init__(model_id, device, extra_params, format)

        if not format:
            self.format = guess_model_format(model_id)

        # load_pretrained_model() follows the LLaVA-style builder signature:
        # (model_path, model_base, model_name, load_8bit, load_4bit, device=...)
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_id, None, model_id.split('/')[-1],
            self.params.get('load_in_8bit', False), self.params.get('load_in_4bit', False),
            device=device)

        print(f"Loaded on device: {self.model.device} with dtype: {self.model.dtype}")

    async def chat_with_images(self, messages: list[Message], max_tokens: int) -> str:
        images, prompt = await prompt_from_messages(messages, self.format)

        # Preprocess the images into the batched tensor the vision tower expects.
        image_tensor = process_images(images, self.image_processor, self.model.config)

        # Splice the IMAGE_TOKEN_INDEX sentinel into the token sequence wherever
        # the <image> placeholder appears in the prompt.
        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.model.device)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                images_aux=None,
                do_sample=False,
                temperature=0.0,
                max_new_tokens=max_tokens,
                bos_token_id=self.tokenizer.bos_token_id, # begin of sequence token
                eos_token_id=self.tokenizer.eos_token_id, # end of sequence token
                pad_token_id=self.tokenizer.pad_token_id, # pad token
                use_cache=True)

        answer = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

        return answer
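
The key step here is tokenizer_image_token(), which interleaves a sentinel image index into the token sequence so the model can substitute vision features at that position. A rough illustration of the idea, not the minigemini implementation; the sentinel value (-200, as in LLaVA) and the BOS-handling detail are assumptions:

    # Illustration only: split on the '<image>' placeholder, tokenize each text
    # chunk, and join the chunks with the sentinel image token index.
    def splice_image_tokens(prompt: str, tokenizer, image_token_index: int = -200) -> list[int]:
        chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
        input_ids: list[int] = []
        for i, chunk in enumerate(chunks):
            if i > 0:
                input_ids.append(image_token_index)
                # drop the duplicate BOS token re-added when tokenizing each chunk
                if chunk and chunk[0] == tokenizer.bos_token_id:
                    chunk = chunk[1:]
            input_ids.extend(chunk)
        return input_ids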
@@ -0,0 +1,70 @@
import os
from transformers import AutoTokenizer, AutoModelForCausalLM

from vision_qna import *

# echo840/Monkey

class VisionQnA(VisionQnABase):
    model_name: str = "monkey"
    format: str = 'qwen' # phi15-ish

    def __init__(self, model_id: str, device: str, extra_params = {}, format = None):
        super().__init__(model_id, device, extra_params, format)

        # XXX currently bugged https://huggingface.co/echo840/Monkey/discussions/4
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=self.params.get('trust_remote_code', False))
        self.model = AutoModelForCausalLM.from_pretrained(**self.params).eval()

        self.tokenizer.padding_side = 'left'
        self.tokenizer.pad_token_id = self.tokenizer.eod_id

        print(f"Loaded on device: {self.model.device} with dtype: {self.model.dtype}")

    async def chat_with_images(self, messages: list[Message], max_tokens: int) -> str:
        files = []
        prompt = ''

        # Monkey expects images as local file paths wrapped in <img> tags, so
        # download each image_url to a temporary file and track it for cleanup.
        for m in messages:
            if m.role == 'user':
                p = ''
                for c in m.content:
                    if c.type == 'image_url':
                        filename = await url_to_file(c.image_url.url)
                        files.append(filename)
                        p = '<img>' + filename + '</img> ' + p
                    if c.type == 'text':
                        p += f"{c.text}\n\n" # Question:
                prompt += p
            elif m.role == 'assistant':
                for c in m.content:
                    if c.type == 'text':
                        prompt += f"Answer: {c.text}\n\n"

        prompt += "Answer:"

        input_ids = self.tokenizer(prompt, return_tensors='pt', padding='longest')

        attention_mask = input_ids.attention_mask.to(self.model.device)
        input_ids = input_ids.input_ids.to(self.model.device)

        pred = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            do_sample=False,
            num_beams=1,
            max_new_tokens=max_tokens,
            min_new_tokens=1,
            length_penalty=1,
            num_return_sequences=1,
            output_hidden_states=True,
            use_cache=True,
            pad_token_id=self.tokenizer.eod_id,
            eos_token_id=self.tokenizer.eod_id,
        )
        # Decode only the newly generated tokens, skipping the prompt echo.
        response = self.tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()

        # Clean up the downloaded temporary image files.
        for f in files:
            os.remove(f)

        return response
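
For a single user message carrying one image and one question, the loop above assembles a prompt of the following shape (the temporary file path is illustrative; url_to_file() chooses the real name):

    # Shape of the prompt built by chat_with_images() above.
    prompt = "<img>/tmp/abc123.jpg</img> What is in this image?\n\nAnswer:"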