#!/usr/bin/env python3
"""
This script produces a realistic PPL (perplexity) measurement for the
quantized KV cache system by processing a sequence of non-overlapping
patches of the reference text. Generation of the consecutive tokens in
each patch is forced (teacher-forced) by the reference text.

The initial context size for the system is set by "--context-size".
The number of output tokens to generate from a given context is set by
"--sample-size"; this value also defines the size of an individual patch.
The optional "--patch-size" overrides the number of patches to process,
which is useful for quick runs.

For an N-token reference text split into M patches with a context size of C,
a run costs roughly M context preloads plus (N - C) generation steps.

Quick correctness validation tips:

Running the llama-2-7b model
(
    ./vllm/benchmarks/measure_ppl2_MC.py
    --model=/data/models/llama-2-7b-chat-hf
    --data=./vllm/tests/prompts/wiki.test.raw
    --context-size=1024
    --sample-size=512
)
should result in PPL ~ 6.524227946419175

Running the llama-2-7b model
(
    ./vllm/benchmarks/measure_ppl2_MC.py
    --model=/data/models/llama-2-7b-chat-hf
    --data=./vllm/tests/prompts/wiki.test.raw
    --context-size=1024
    --sample-size=512
    --patch-size=1
)
should result in PPL ~ 3.8968611189957523
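
Illustrative patch-count example (hypothetical numbers, mirroring the
computation in main()): for a reference text of N = 10000 tokens with
--context-size=1024 and --sample-size=512, the script processes
M = ceil((10000 - 1024 - 1) / 512) = 18 patches, unless --patch-size
overrides that count.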
"""
import argparse
import datetime
import math

from transformers import LlamaTokenizer

from vllm import LLM, SamplingParams
from vllm.logger import init_logger

logger = init_logger(__name__)
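

# NOTE: get_wikitext2_text() reads ``args`` from module scope (it is assigned
# in the ``__main__`` block at the bottom of the file), so it must only be
# called after argument parsing.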
def get_wikitext2_text(tokenizer):
    with open(args.data) as f:
        test_text = "\n".join(line.strip() for line in f)
        test_enc = tokenizer(test_text)

    return test_enc, test_text


def vllm_init(args):
    llm = LLM(
        model=args.model,
        tokenizer=None,
        tensor_parallel_size=args.tensor_parallel_size,
        trust_remote_code=args.trust_remote_code,
        dtype=args.dtype,
        kv_cache_dtype=args.kv_cache_dtype,
        #scales_path=args.kv_cache_scales_path
        # if args.kv_cache_scales_path!='' else None,
        quantization_param_path=args.kv_cache_scales_path
        if args.kv_cache_scales_path != '' else None,
        enforce_eager=args.enforce_eager)

    # ppl_measurement and future_context are fields added by this fork:
    # generation is forced to follow future_context (the reference
    # continuation) while the log-probabilities of the forced tokens are
    # recorded.
    sampling_params = SamplingParams(n=1,
                                     temperature=0.0,
                                     top_p=1,
                                     use_beam_search=False,
                                     ignore_eos=True,
                                     ppl_measurement=True,
                                     future_context=[],
                                     presence_penalty=0.0)

    return llm, sampling_params
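

# NOTE: llm.generate() is called below with the vLLM interface of the version
# this fork is based on, where prompt_token_ids is passed as a list of
# token-id lists; newer vLLM releases have changed this interface.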
def vllm_predict(CONT, llm, sampl_par):
    result = llm.generate(prompt_token_ids=CONT, sampling_params=sampl_par)
    return result


def main(args: argparse.Namespace):
    MESSAGE = f"Initialising @ {datetime.datetime.now()}"
    logger.info(MESSAGE)
    print(MESSAGE)
    my_ppl = 0.0

    my_tokenizer = LlamaTokenizer.from_pretrained(args.model)
    logger.info("Loaded the tokenizer.")
    logger.info("Initializing the engine.")
    my_llm, my_sampl_par = vllm_init(args)
    logger.info(my_sampl_par)
    logger.info("Initialized the engine.")

    my_test_enc, my_test_text = get_wikitext2_text(my_tokenizer)
    logger.info("Loaded the test data.")

    # Each patch contributes my_n_samples forced tokens; the number of patches
    # covers the reference text unless --patch-size overrides it.
    my_n_samples = args.sample_size
    my_n_patches = math.ceil(
        (len(my_test_enc['input_ids']) - args.context_size - 1) / my_n_samples)
    if args.patch_size is not None:
        my_n_patches = args.patch_size

    num_tokens_generated = 0
    starting_time = datetime.datetime.now()
    MESSAGE = (f"Starting generation @ {starting_time}, will try to process "
               f"{my_n_patches} patch(es), generating {my_n_samples} tokens "
               f"in each patch from the initial context of "
               f"{args.context_size} tokens.")
    logger.info(MESSAGE)
    print(MESSAGE)

    for c in range(my_n_patches):
        # Context: the reference tokens immediately preceding this patch.
        CONTEXT = []
        my_sampl_par.future_context = []
        CONTEXT.append(
            my_test_enc['input_ids'][c * my_n_samples:c * my_n_samples +
                                     args.context_size])
        # Future context: the reference continuation that generation is forced
        # to follow and whose log-probabilities are accumulated.
        upper_boundary = min((c + 1) * my_n_samples + args.context_size,
                             len(my_test_enc['input_ids']))
        my_sampl_par.future_context.append(
            my_test_enc['input_ids'][c * my_n_samples +
                                     args.context_size:upper_boundary])
        my_sampl_par.max_tokens = len(my_sampl_par.future_context[0])
        my_sampl_par.cntr = c
        LOGPROBS = vllm_predict(CONTEXT, my_llm, my_sampl_par)
        num_tokens_generated += len(LOGPROBS[0].outputs[0].token_ids)
        # cumulative_logprob is the sum of log-probabilities of the forced
        # tokens; negating and accumulating it gives the total cross-entropy.
        my_ppl -= LOGPROBS[0].outputs[0].cumulative_logprob
        MESSAGE = (
            f"Iteration {c + 1} of {my_n_patches} intermediate estimates:\n"
            f"\tCross-entropy_intermediate={my_ppl / num_tokens_generated}\n"
            f"\tPerplexity_intermediate="
            f"{math.exp(my_ppl / num_tokens_generated)}")
        logger.info(MESSAGE)
        print(MESSAGE)

    ending_time = datetime.datetime.now()
    MESSAGE = (f"Done @ {ending_time} after processing for "
               f"{ending_time - starting_time}, generated "
               f"{num_tokens_generated} tokens.")
    logger.info(MESSAGE)
    MESSAGE = (f"Integral Cross-Entropy={my_ppl} Average Cross-Entropy="
               f"{my_ppl / num_tokens_generated} "
               f"PPL={math.exp(my_ppl / num_tokens_generated)}")
    logger.info(MESSAGE)
    print(MESSAGE)
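

# Illustrative sketch (not called by main()): how the final figures above are
# obtained from the accumulated negative log-probability. `total_neg_logprob`
# corresponds to my_ppl and `num_tokens` to num_tokens_generated.
def _ppl_from_neg_logprob(total_neg_logprob: float, num_tokens: int) -> float:
    """Perplexity is the exponential of the average per-token cross-entropy."""
    return math.exp(total_neg_logprob / num_tokens)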


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Measure the perplexity (PPL) of a model over a reference '
        'text by teacher-forced generation over non-overlapping patches.')
    parser.add_argument('--model', type=str, default='facebook/opt-125m')
    parser.add_argument(
        '--data',
        type=str,
        default='./wikitext/wikitext-2-v1/test-00000-of-00001.parquet',
        help='path to the reference text file (read as raw text)')
    parser.add_argument('--context-size',
                        type=int,
                        default=4096,
                        help='number of reference tokens preloaded as the '
                        'initial context of each patch')
    parser.add_argument('--kv-cache-scales-path', type=str, default='')
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--sample-size',
                        type=int,
                        default=512,
                        help='number of tokens to generate per patch; also '
                        'defines the patch size')
    parser.add_argument('--patch-size',
                        type=int,
                        default=None,
                        help='if set, overrides the number of patches to '
                        'process')
    parser.add_argument('--enforce-eager',
                        action='store_true',
                        help='enforce eager mode and disable CUDA graph')
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=['auto', 'fp8_e5m2', 'fp8'],
        default='auto',
        help=
        'Data type for kv cache storage. If "auto", will use model data type.')
    args = parser.parse_args()
    main(args)