inference.py
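"""Run inference with google/gemma-2b, either the base model or a Determined checkpoint.

For each of the first `number_of_samples` prompts in Intel/orca_dpo_pairs, the script
prints the model input, the reference ("chosen") answer, and the model's generation,
and optionally writes all three columns to a CSV file.

Example invocation (IDs and file names are illustrative):
    python inference.py --trial_id 123 --output_file results.csv
"""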
import argparse
import glob
from typing import Optional

import pandas as pd
from datasets import load_dataset
from determined.experimental import client

from chat_format import maybe_add_generation_prompt
from utils import get_model, get_tokenizer

def main(
    exp_id: Optional[int],
    trial_id: Optional[int],
    device: str,
    output_file: Optional[str],
    number_of_samples: int,
) -> None:
    model_name = "google/gemma-2b"
    if exp_id is None and trial_id is None:
        # No experiment or trial given: evaluate the base model straight from the Hub.
        checkpoint_dir = model_name
        is_base_model = True
    else:
        # Fetch the most recent checkpoint of the given Determined experiment or trial.
        exp = client.get_experiment(exp_id) if exp_id else client.get_trial(trial_id)
        checkpoint = exp.list_checkpoints(
            max_results=1,
            sort_by=client.CheckpointSortBy.BATCH_NUMBER,
            order_by=client.OrderBy.DESCENDING,
        )[0]
        checkpoint_dir = checkpoint.download(mode=client.DownloadMode.MASTER)
        checkpoint_dir = glob.glob(f"{checkpoint_dir}/checkpoint-*")[0]
        is_base_model = False

    model = get_model(checkpoint_dir, inference=True, device_map=device)
    tokenizer = get_tokenizer(
        checkpoint_dir,
        truncation_side="right",
        model_max_length=8192,
        add_eos_token=False,
    )

    results = {"input": [], "output": [], "correct": []}
    dataset = load_dataset(
        "Intel/orca_dpo_pairs", split=f"train[1:{number_of_samples + 1}]"
    )
    for element in dataset:
        if not is_base_model:
            # Fine-tuned model: wrap the prompt in the tokenizer's chat template.
            formatted = tokenizer.apply_chat_template(
                conversation=[
                    {
                        "role": "user",
                        "content": element["system"] + "\n" + element["question"],
                    },
                ],
                tokenize=False,
            )
            formatted = maybe_add_generation_prompt(formatted)
        else:
            # Base model: plain concatenation of system prompt and question.
            formatted = element["system"] + "\n" + element["question"]

        inputs = tokenizer(formatted, return_tensors="pt").to(device)
        input_str = tokenizer.batch_decode(inputs["input_ids"])[0]
        print(f"Model input: {input_str}")

        outputs = model.generate(
            **inputs, eos_token_id=tokenizer.eos_token_id, max_new_tokens=1000
        )
        # Decode only the newly generated tokens, dropping the echoed prompt.
        input_length = inputs["input_ids"].shape[1]
        response = tokenizer.batch_decode(
            outputs[:, input_length:], skip_special_tokens=True
        )

        print(f"\n\nCorrect response:\n{element['chosen']}")
        print(f"\n\nLLM response:\n{response[0]}")
        results["input"].append(input_str)
        results["output"].append(response[0])
        results["correct"].append(element["chosen"])

    if output_file:
        df = pd.DataFrame.from_dict(results)
        df.to_csv(output_file, index=False)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_id", type=int, default=None, required=False)
    parser.add_argument("--trial_id", type=int, default=None, required=False)
    parser.add_argument("--device", type=str, default="cuda")
    # The output CSV is optional; results are only written when a path is given.
    parser.add_argument("--output_file", type=str, default=None, required=False)
    parser.add_argument("--number-of-samples", type=int, default=100, required=False)
    args = parser.parse_args()
    main(
        args.exp_id,
        args.trial_id,
        args.device,
        args.output_file,
        args.number_of_samples,
    )