-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathinference.py
96 lines (91 loc) · 2.59 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import argparse
from inference_pipeline import InferencePipeline
def args_init():
    """Build and parse the command-line options for an inference run.

    Returns:
        argparse.Namespace: parsed options covering data location, decoding
        strategy, model/checkpoint selection, generation config, and the
        API credentials needed by hosted models.
    """
    # One (flag, kwargs) row per option so the whole CLI surface is
    # scannable in a single table; behavior matches individual
    # parser.add_argument(...) calls exactly.
    option_table = [
        ("--data_path", dict(
            type=str,
            default="../data/ClassEval_data.json",
            help="ClassEval data",
        )),
        ("--greedy", dict(
            type=int,
            default=1,
            help="Whether to generate model results with greedy strategy",
        )),
        ("--output_path", dict(
            type=str,
            default="model_output.json",
            help="output file path",
        )),
        ("--cuda", dict(
            type=int,
            nargs="+",  # accept one or more device indices
            default=None,
            help="List of CUDA device(s), default value is None. If not set, use all available devices.",
        )),
        ("--generation_strategy", dict(
            type=int,
            default=0,
            help="Holistic = 0, Incremental = 1, Compositional = 2",
        )),
        ("--model", dict(
            type=int,
            default=1,
            help="Instruct_CodeGen = 0, WizardCoder = 1, Instruct_StarCoder = 2, InCoder = 3, PolyCoder = 4, SantaCoder = 5, Vicuna = 6, ChatGLM = 7, GPT_3_5 = 8, GPT_4 = 9, others = 10, Magicoder = 11, CodeGeeX2 = 12, DeepSeekCoder_inst = 13, Gemini_Pro = 14, CodeLlama_13b_inst = 15",
        )),
        ("--checkpoint", dict(
            type=str,
            default="WizardLM/WizardCoder-15B-V1.0",
            help="checkpoint of the model",
        )),
        ("--temperature", dict(
            type=float,
            default=0.2,
            help="temperature value in generation config",
        )),
        ("--max_length", dict(
            type=int,
            default=2048,
            help="max length of model's generation result",
        )),
        ("--openai_key", dict(
            type=str,
            default="openai_key",
            help="need openai key if use GPT-3.5 or GPT-4",
        )),
        ("--openai_base", dict(
            type=str,
            default="openai_base",
            help="need openai base if use GPT-3.5 or GPT-4",
        )),
        ("--google_api_key", dict(
            type=str,
            default="google_api_key",
            help="need google api key if use Gemini Pro",
        )),
        ("--sample", dict(
            type=int,
            default=5,
            help="The number of code samples that are randomly generated for each task.",
        )),
    ]

    parser = argparse.ArgumentParser()
    for flag, kwargs in option_table:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
if __name__ == '__main__':
    # Entry point: parse the CLI options, then run the full inference pipeline.
    cli_args = args_init()
    InferencePipeline(cli_args).pipeline()