-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinterface.py
80 lines (68 loc) · 1.96 KB
/
interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import torch
import transformers
from peft import PeftModel
from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
import json
import numpy as np
# Select the compute device: prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Paths: base LLaMA weights, the fine-tuned PEFT adapter, and the
# mPLUG caption data to be verified.
llama_path = "path/to/your/llama/checkpoint"
checkpoint_path = "checkpoint"
data_path = "data/mPLUG_caption.jsonl"

# Load the tokenizer and base model in fp16; device_map="auto" lets
# HF accelerate place (and shard) the weights across available devices.
tokenizer = LlamaTokenizer.from_pretrained(llama_path)
model = LlamaForCausalLM.from_pretrained(
    llama_path,
    load_in_8bit=False,
    torch_dtype=torch.float16,
    device_map="auto",
)

# Attach the fine-tuned LoRA/PEFT adapter on top of the base model.
model = PeftModel.from_pretrained(
    model,
    checkpoint_path,
    force_download=True,
    torch_dtype=torch.float16,
)

# LLaMA special-token ids: pad=0, bos=1, eos=2 (set on both tokenizer
# and model config so generation pads/stops consistently).
model.config.pad_token_id = tokenizer.pad_token_id = 0
model.config.bos_token_id = 1
model.config.eos_token_id = 2

model.eval()

# torch.compile exists only from PyTorch 2.x onward.  Compare the major
# version numerically: the original string comparison (__version__ >= "2")
# is lexicographic and would misclassify a hypothetical "10.x".
if int(torch.__version__.split(".")[0]) >= 2:
    model = torch.compile(model)
# Read the JSONL caption file: one JSON record per line.
with open(data_path, 'r', encoding='utf-8') as f:
    data = [json.loads(line.strip()) for line in f]
# Decoding hyper-parameters: low temperature plus beam search for a
# near-deterministic, short verdict from the model.
_decode_params = dict(
    temperature=0.1,
    top_p=0.75,
    top_k=40,
    num_beams=4,
)
generation_config = GenerationConfig(**_decode_params)
# For each record, prompt the model with the reference captions and "our"
# caption, and generate a single token as its accuracy verdict.
for d in data:
    # Record's image id; unused below, but the lookup asserts the key exists.
    image_id = d["image"]
    captions = d["reference caption"]

    # Build the verification prompt.  " ".join replaces the original
    # quadratic "+=" loop followed by prompt[:-1]; that slice also ate the
    # newline after "reference captions:" whenever the caption list was
    # empty, which join handles correctly.
    prompt = (
        "reference captions:\n"
        + " ".join(captions)
        + "\nour caption:\n"
        + d["output"].split("\n")[0].strip()
        + "\nIs our caption accurate?\n"
    )

    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)

    # Generate exactly one new token (the verdict); no gradients needed.
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=1,
        )

    # Decode the full sequence (prompt + the one generated token) and keep
    # only the text after the final newline, i.e. the model's answer.
    sequence = generation_output.sequences[0]
    sentence = tokenizer.decode(sequence.tolist(), skip_special_tokens=True)
    result = sentence.split("\n")[-1]
    print(result)