-
Notifications
You must be signed in to change notification settings - Fork 10.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
server : add speculative decoding support #10455
Conversation
1973399
to
7dc6ae5
Compare
From what I have read the goal is faster inference while retaining quality of the larger model. I am using rx6900xt with vulkan I get about 10-12 t/s with an incorrect configuration.
Flipping the models increased speed and the output looks similar. This makes sense since the -md is the draft model which is supposed to be the smaller model. I get about 16 t/s with the correct configuration.
Setting a lower context 2048, when the limit is reached the server crashed. |
c5ddee2
to
e80f758
Compare
@3Simplex What is the output of the following bench on your machine: llama-bench.exe -m "...Qwen2.5-Coder-7B-Instruct-Q8_0.gguf" -p 1,1,2,3,4,5,6,7,8,12,16,32 -r 20 -n 0 -ngl 99 -fa 1 |
.\llama-bench.exe -m "...\Qwen2.5-Coder-7B-Instruct-Q8_0.gguf" -p 1,1,2,3,4,5,6,7,8,12,16,32 -r 20 -n 0 -ngl 99 -fa 1
build: 0c74590 (4160) |
I tried out commit e80f758 with my P40s, 3xP40s and 3090. These are the commands for the baselines and the tests. Baseline:
With speculative model (just removed the
Tested it with curl using:
Data:
|
e80f758
to
d905266
Compare
Currently, it requires
The biggest benefit from speculative sampling is when you have more grounding. For example, if you have enough memory for a bigger context, you can try something like this: # get the llama.vim plugin source code
code=$(curl -s https://raw.githubusercontent.com/ggml-org/llama.vim/refs/heads/master/autoload/llama.vim | jq -sRr @json)
# ask qwen to implement something (speculative decoding disabled)
curl --request POST --url http://localhost:8033/v1/chat/completions -H "Content-Type: application/json" -H "Authorization: Bearer no-key" -d "$(jq -n --arg code "$code" \
'{ messages: [{ role: "system", content: "You are an expert computer scientist. Respond only with code blocks. Do not add any other comments except code." }, { role: "user", content: "Suggest an improvement for the `chunk_sim` function using Levenstein distance: ```\($code)```" }], cache_prompt: true, top_k: 1, samplers: ["top_k"], "speculative.n_max": 0 }')" | jq -r .choices[0].message.content
# speculative decoding enabled
curl --request POST --url http://localhost:8033/v1/chat/completions -H "Content-Type: application/json" -H "Authorization: Bearer no-key" -d "$(jq -n --arg code "$code" \
'{ messages: [{ role: "system", content: "You are an expert computer scientist. Respond only with code blocks. Do not add any other comments except code." }, { role: "user", content: "Suggest an improvement for the `chunk_sim` function using Levenstein distance: ```\($code)```" }], cache_prompt: true, top_k: 1, samplers: ["top_k"], "speculative.n_max": 16 }')" | jq -r .choices[0].message.content With CUDA, you might want to try setting |
Thank you for the guidance. Using d905266, I reran the tests. Results look quite good.
Server command:
Kept this pretty consistent, except for the 3xP40 run where I added Client side:
For the client side curl, I changed Here are the raw results. Some observations first:
3090 data
single P40
3xP40 (-sm row)
Code generated: function! s:chunk_sim(c0, c1)
let l:lines0 = join(a:c0, "\n")
let l:lines1 = join(a:c1, "\n")
let l:distance = levenshtein(l:lines0, l:lines1)
return 1 - (l:distance / max([strlen(l:lines0), strlen(l:lines1)]))
endfunction
function! levenshtein(s1, s2)
let l:len1 = strlen(a:s1)
let l:len2 = strlen(a:s2)
if l:len1 == 0
return l:len2
endif
if l:len2 == 0
return l:len1
endif
let l:dp = []
for i in range(l:len1 + 1)
call add(l:dp, [])
for j in range(l:len2 + 1)
call add(l:dp[i], 0)
endfor
endfor
for i in range(l:len1 + 1)
let l:dp[i][0] = i
endfor
for j in range(l:len2 + 1)
let l:dp[0][j] = j
endfor
for i in range(1, l:len1 + 1)
for j in range(1, l:len2 + 1)
let l:cost = (strcharpart(a:s1, i - 1, 1) == strcharpart(a:s2, j - 1, 1)) ? 0 : 1
let l:dp[i][j] = min([l:dp[i - 1][j] + 1, l:dp[i][j - 1] + 1, l:dp[i - 1][j - 1] + l:cost])
endfor
endfor
return l:dp[l:len1][l:len2]
endfunction |
Also, is |
Thanks for the detailed tests. The results are inflated because there is one tricky side effect from the caching - consecutive runs with the same prompt will reuse the previous draft context which combined with greedy sampling would make the drafting instantaneous. So basically, in the following data for example, only the first result is relevant:
i.e.
This was a bug - it is fixed now. You should be able to change Btw, here is another fun test that I came up with which uses less context and is suitable for speculation: # get top 10 stories from Hacker News
hn=$(curl -s https://hacker-news.firebaseio.com/v0/topstories.json | jq -r '.[:10] | @tsv' | tr '\t' '\n' | xargs -I {} curl -s "https://hacker-news.firebaseio.com/v0/item/{}.json" | jq -sRr @json)
# make a Markdown table based on some criteria
curl --request POST --url http://localhost:8033/v1/chat/completions -H "Content-Type: application/json" -H "Authorization: Bearer no-key" -d "$(jq -n --arg hn "$hn" \
'{ messages: [{ role: "system", content: "You are a helpful text-editing assistant. Respond only with the requested text. Do not add any other comments to your response." }, { role: "user", content: "Extract a Markdown table that contains only stories about software engineering, AI or machine learning from the front-page of HN. The table should include: author, title, score, comments and an URL to the story: ```\($hn)```." }], cache_prompt: true, top_k: 1, samplers: ["top_k"], "speculative.n_max": 16 }')" | jq -r .choices[0].message.content |
Thanks. That seems a lot more realistic. I did some tests with a much shorter prompt: "write snake game in swift"
|
These numbers look reasonable. The speedup can vary in both ways based on the inputs, but enabling speculative should almost never result in slower than normal decoding. |
With this build I am up to 25t/s on first run generation with speculative decoding using 15/5 draft tokens. |
A bit of data with llama-3.1 70B and llama-3.2 1B as the draft model. Prompt: "write a story about the natural resources in Canada".
Server:
client (changed speculative.n_max between
|
Note that I am not very sure what happens with multiple GPUs, but it is possible that the draft model gets split across them, which is not desired (see the logs if that is the case). You would want to keep the draft model fully on one GPU. |
c277c4d
to
156aa6d
Compare
I wonder if it is possible to load draft and main model onto different backend. Ie a 7900xtx and P40 in a -cb process |
@dagbdagb
repository
AMD Ryzen 9 7940HS w/ Radeon 780M Graphics
Total 32.0GB available 27.8 GB (4GB for iGPU)
extract specific data from around 1500 tokens of text in Japanese (repeat 26 times)
(1)b4219(llama.cpp official binary)
5764.07 second
4968.42 second (2)locally built myself(b4227)
5807.13 second
5003.03 second (3)ROCm (b4215)
1576.67 second I feel that the 2B model may not be able to run fast enough on the CPU. This causing a bottleneck. |
Tried it with Qwen-2.5 on my 2x 3090s. No performance improvements whatsoever with 72b split across both GPUs. Lost some performance, actually. Ran a bunch of experiments using different hints I picked up here. No performance gains still, running the 14b variant on the same gpu as draft models (tried 0.5, 1.5b, 3b) or the other gpu, any permutation of draft-p-min and speculative.n_max. Best I got was 2-3 tps more (around 56 tps) as compare to ~54 I'm getting running without the draft model.
Maybe I'm missing something obvious, but no clue how other folks got such huge performance gains. |
I tested b4240 with 150% speed bump which must be an optimal use case with my non-optimal hardware (16GB P5000 + 8GB RTX 2070 Max-Q). These tweaks contributed to that: draft-min 0, draft-p-min 0.5 and temperature 0.1.
|
@mybyte Did you remember to set |
@JeroenAdam This should be fixed now (#10586) |
Yup. Tried both, UI and my app via REST. I wonder if a potential culprit is the model choice? |
First - wow ; is this some upgrade- Hats off. So I "mixed" two versions of the same model (llama3, 8B) , one reg quant, one imatrix quant (draft): `` `` Using a custom sampler profile (includes use of DRY, temp .84, and penalty samplers) ... here is part output from a ONE LINE prompt: _I stood frozen at my desk, staring into the abyss of the office park outside our high-rise building. My fingers were still poised above the keyboard, the cursor blinking mockingly at me like some malignant eye. The air was thick with tension, the kind that precedes an announcement no one wants to hear. But this was different. This was personal. As I gazed out over the sea of cubicles and fluorescent lights, something began to rumble beneath my feet, like thunder growing nearer. A low, ominous hum filled my ears, making my teeth ache. Suddenly, the sky scraper swayed, as if buffeted by a storm force wind. I grabbed onto the edge of my desk for support, my heart racing with alarm. My coworkers gasped and screamed as we all stumbled about, our chairs scraping against tile floors. Glass shattered somewhere nearby, and a panicked shout rose up from below. It felt like a seismic event, a catastrophic shift in the earth's crust that had nothing to do with geology. Then came the blast._ You don't get this level of detail at 8B. Or prose, or metaphors. This is like a MOE on steroids. A lot of steroids. |
@David-AU-github I don't think you quite understand how speculative decoding works. It will generate identical results to the non-speculative decoding case and will always generate what the main model would have generated on its own. It's only useful as a speed boost, it will not alter the output at all. |
@Mushoz Also - number of "drafts" ? Great variance here in output here. I also tested this model - both main and draft - separately to see if I could replicate this level of detail. Note : The two models - even thou the same model - one is imatrix version and the other non-imatrix version. In static tests (temp=0) each model will output different content from the same prompt. |
Normally a model can only predict one token at a time, because the token at position N depends on all previous tokens 0 through N-1. It would be much quicker if a model could predict not only token at position N, but also (depending on number of drafts) N+1, N+2, N+3. The reason why this is much faster, is because all the weight data of the big slow model only needs to be retrieved once for all 4 tokens, and LLMs are generally memory bandwidth limited. More calculations need to be done, but GPUs are extremely good at parallel computations which is what this is. But this cannot normally be done, because you need all previous tokens to be able to generate the next. What the draft model does, is generate a sequence of draft tokens N, N+1, N+2. The big model then assumes these to be true and generates 1 token ahead of each of these draft token, so it can do multiple at the same time. That means that despite the draft model generating N, N+1, N+2, the big model still generates these as well to verify them, but is able to do so in parallel (fast) instead of in sequence as is done in normal generation. If the base model generates a different token than what the draft model predicted, all subsequent tokens are discarded and the drafting is started all over again. This means that tokens are only retained if the draft model predicted the token the base model generated, which is why the output in speculative decoding is identical to what the base model would have generated in non speculative decoding. And this is also why a speed up is only observed if the predictions are good enough, because if not, all the extra work is simply discarded. |
Thank you. To clarify ; you get a speed increase in the vast majority of draft sequence tokens are "in agreement" between the draft and main model. The "draft" min / max is the number of tokens to generate for sequence? -> that is the min/max size of sequence? If the draft sequence / token(s) are not in agreement do both models "go back to the drawing board" and both models "redraft" a sequence of tokens? If this point is true - specifically both models - , that explains what I am observing and I can work with that. It sounds like when I am err... using speculative decoding in this way it is forcing different choices to occur than would otherwise happen. Almost like a strange version of temp and/or a "light" rep pen sampler ? or adding a random element into the generation? I have tested this method with other models / archs too and observing an increase in generational quality, with a decrease in T/S. |
The maximum amount of tokens to draft, is just that: How long of a sequence the draft model will draft. The higher this value is, the higher potential speed increase (up to a maximum, where you become compute bound), as long as the predictions are correct. For longer sequences, the draft model will most definitely generate something else than the main model. That means there is a sweet spot somewhere. Too high and you're just wasting work that is discarded, leading to slowdowns. The draft min is a variable that will tune how long your draft sequence needs to be at a minimum before the main model uses the draft predictions to do the final predictions. Some GPUs might not be terribly efficient at certain batch sizes, so it might be better to force them to higher batch sizes where the kernels are better optimized for batch processing. When the main model is in disagreement, all draft tokens are discarded and all tokens generated by the main model that were BASED ON THE INCORRECT DRAFT TOKEN(S) are discarded as well. Importantly, the token that was generated by the main model that proved the draft wrong is NOT discarded, and the main model essentially falls back to the normal non-speculative decoding case. Again, the main model will generate the exact same tokens with speculative decoding on vs off. The differences you are observing are purely due to sampler settings. Speculative decoding does not alter the output in any way, and anything you believe you are seeing is merely placebo. |
Excellent. thank you. RE: How long does the model fall back into "normal non-speculative decoding" operation? Until the next sequence of draft tokens from the draft model? |
It doesn't really fall back in the literal sense. What I mean is that the draft tokens that were generated incorrectly are simply ignored as if speculative decoding had never been turned on in the first place. Speculative decoding will remain effective in the sense that the draft model will immediately generate a new draft sequence after getting corrected by the main model and the main model will then again use that sequence to do the validation, just as it had been doing before. |
Note that this only applies to the incorrect draft token itself, and all subsequent tokens (as they are based on an incorrect preceding token). All correct draft tokens before the incorrect one are retained of course. If a draft sequence is 16 tokens long, it's perfectly possible the first 8 tokens are correct (which are retained) and the 9th is incorrect, which means token 9 through 16 of the sequence are discarded. |
@Mushoz Thank you for your help with this. There is an interesting divergence for creative use with very low bit quants VS mid/high which may benefit or be a benefit. (this is separate and part from spec decoding). Hmmm. Never mind two different models all together (with same vocab)... hmmm 2x. MOEs... raise even more questions. |
I can't seem to get any performance gains on my Mac. I ordered a brand new M4 Max 128GB to get the most out of it. I'm runnign a Q4 K-M version of L3.3 70bn as the verification model and have tried a Q4 K-M version of L3.1 8bn, L3.2 3bn and L3.2 1bn as the drafting model and in no constellation do I get any benefit. We've seen significant benefits with Nvidia hard- and software, but I'd like to see the same on Metal. I'm using this command on the server: ./build/bin/llama-server -m ../models/verification.gguf -md /Users/mattsinalco/.cache/huggingface/hub/models--unsloth--Llama-3.2-1B-Instruct-GGUF/snapshots/a5594fb18df5dfc6b43281423fcce6750cd92de5/Llama-3.2-1B-Instruct-Q4_K_M.gguf -ngl 99 -ngld 90 --port 8033 -c 4096 --draft-min 5 --draft-max 16 --temp 0.0 --draft-p-min 0.5 And this on the client: "cache_prompt": true, WIth just the 70bn verification model, I'm getting 8.7 t/s, but with the setup described here I get 7.99 and it doesn't really mattter which drafting model I use. Is there a problem with Metal, @ggerganov ? |
Could be related to this: #10581 (speculative decoding not yielding benefits for quantized models). |
Are you using latest This is the config I am using and it works pretty good on M2 Ultra: ./build-chat/bin/llama-server \
-m ./models/qwen2.5-32b-coder-instruct/ggml-model-q8_0.gguf \
-md ./models/qwen2.5-1.5b-coder-instruct/ggml-model-q4_0.gguf \
--log-file ./service-chat.log \
--host 0.0.0.0 --port 8013 \
--ctx-size 0 \
--cache-reuse 256 \
-ub 4096 -b 4096 -ngl 99 -ngld 99 -fa -dt 0.1 -lv 1 -t 1 --draft-max 16 --draft-min 5 |
* server : add speculative decoding support ggml-ci * server : add helper function slot.can_speculate() ggml-ci
Thanks, @ggerganov. I'm on the latest build. I use this command: ./build/bin/llama-server And this on the client: "top_k": 1, And the result varies quite a bit - from slower to same to a little faster than just running the Q4 K-M 70bn directly. I copy some of the logging below. I can't read the logs, but from what I can tell, the speculative decoding part is working, but the draft_candidates have a lot of 0s and 1s, which I assume are probabilities - is this how it should be? slot update_slots: id 0 | task 118 | new prompt, n_ctx_slot = 4096, n_keep = 0, n_prompt_tokens = 2104 ' slot process_toke: id 0 | task 118 | n_decoded = 1, n_remaining = -1, next token: 5018 '{"'
|
BTW, @ggerganov: From what I can tell, llama-server doesn't support multi-slot KV caches right now, allowing different prompts to maintain separate caches simultaneously. Is that right? This feature would go a long way to speeding up function calling on Metal. Potentially with disk offloading, although I suspect most agentic apps should get by on a dozen different base prompts. I've seen references to this for llama-cli, but I don't think llama-server supports this. I've asked one of my team members to look into this and make a proposal for adding it. |
Hmmm... the zero/one probabilities I got when I ran a 4-bit 8bn drafting model that was fine-tuned on the same data as the verification model. If I use a non-finetuned 1bn or 3bn model, the probabilities vary more, but still no sustained speed-up (just sometimes). Next step: Fine-tune the 1bn model. That shoudl give me more speed than the 8bn bu thten hopefully better proability distributions. |
Okay, so this is already addresse (and solved) here: #9135 |
@ggerganov - I've finally got some good news to report. As reported previously, speculative decoding had almost no effect on my M4 Max. I reliably got 8.7 t/s for Llama 3.3 70bn with a 3bn drafting model. But when I started to crank up the -np value, t/s went up significantly. It maxed out at -np 7, when I reliably got 16+ t/s (a doubling of what I had originally, which is amazing). When I went to -np 8, performance cratered and I ended up with 7 t/s. Cranking up the -np value when I run JUST the 70bn model with a drafting model has no effect, so it's definitely an enabler for speculative decoding, at least on my setup. Next up: Getting KV caching to work. With no -np parameter, the system prompt gets cached. With the -np parameter, I see no such effect. I will now try to force the server to use pre-populated KV slots. Will report back when I have that working. |
I'm confused. Shouldn't the number of slots available for KV caching be independent of whatever causes spec. decoding to speed up when I set -np to 6 or 7? |
target #10362
Initial implementation that enables speculative decoding in
llama-server
. Test with this command:--draft-max
and--draft-min
might need tuningllama.cpp
Web UI clientTop K = 1
-devd
argument to put the draft model on only one of them (llama : accept a list of devices to use to offload a model #10497)Feedback is appreciated.
TODO:
server.params
to something else to avoid confusions