inconsistent logits for identical inputs in batch (metal) #1941
Comments
All else being otherwise equal, this encourages the beam candidate selection to re-use the same decoder, which slightly reduces the cache size. I wouldn't expect it to make much of a performance difference, but it helps when debug printing the cache and beam. Added as part of understanding ggerganov#1941.
Oh, and this reproduces with
In ggerganov/llama.cpp#4130 (comment) In short, we put the tokens from all sequences in the same KV cache buffer and construct a suitable KQ mask that is used to discard cross-sequence attention. This is in contrast to a more straightforward approach in which each sequence has its own separate KV cache buffer and the attention for each sequence is computed independently from the rest - the KQ mask in this case is simply a causal mask. A drawback of the unified KV cache is that the results for different sequences are now also a function of where their tokens end up in the buffer. The reason is that the
I expect a similar effect to be observed on the CPU - maybe it's harder to encounter
Yes, it is normal for these to be different
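To make the unified-cache idea from the comment above concrete, here is a minimal sketch of building a KQ mask that permits attention only within the same sequence and only to earlier-or-equal positions. The `kv_cell` struct and `build_kq_mask` function are hypothetical names invented for illustration; they are not the ggml or whisper.cpp API, and the real implementation may lay out the mask differently.

```cpp
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical description of one token: which sequence it belongs to
// and its position within that sequence.
struct kv_cell {
    int seq_id;
    int pos;
};

// Build a row-major (queries x cells) additive mask:
// 0 where attention is allowed, -inf where it must be discarded.
std::vector<float> build_kq_mask(const std::vector<kv_cell> & queries,
                                 const std::vector<kv_cell> & cells) {
    const float neg_inf = -std::numeric_limits<float>::infinity();
    std::vector<float> mask(queries.size() * cells.size(), neg_inf);
    for (std::size_t i = 0; i < queries.size(); ++i) {
        for (std::size_t j = 0; j < cells.size(); ++j) {
            const bool same_seq = cells[j].seq_id == queries[i].seq_id; // no cross-sequence attention
            const bool causal   = cells[j].pos    <= queries[i].pos;    // no attention to the future
            if (same_seq && causal) {
                mask[i * cells.size() + j] = 0.0f;
            }
        }
    }
    return mask;
}
```

With two sequences interleaved in the buffer, each query row has zeros only in the columns that hold its own sequence's tokens. Which columns those are depends on where the tokens landed in the buffer, which is the position dependence the comment describes.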
Got it. Thanks for the details. What do you think about using a mechanism other than floating point equality in the beam search code I linked to above, since floating point equality is no longer expected? The easiest and simplest approach is checking whether the sequence token ids are identical. That's O(n), but n is small and the comparisons are cheap. More complicated would be to keep a running hash of token ids. I implemented the O(n) approach while investigating this; it was straightforward.
Yes, sounds like a good idea
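The two options mentioned here can be sketched roughly as follows. This is an illustration under stated assumptions, not the code from the actual patch: `token_id`, `beam_candidate`, `same_sequence`, and `token_hash` are hypothetical names, and the hash shown is plain 64-bit FNV-1a, chosen only because it is easy to update incrementally.

```cpp
#include <cstdint>
#include <vector>

using token_id = int32_t; // assumption: token ids fit in 32 bits

// Hypothetical beam candidate: the decoded tokens plus their summed log probability.
struct beam_candidate {
    std::vector<token_id> tokens;
    double sum_logprobs;
};

// O(n) approach: two candidates are duplicates only if their token ids match exactly.
// The summed log probability is deliberately ignored.
bool same_sequence(const beam_candidate & a, const beam_candidate & b) {
    return a.tokens == b.tokens;
}

// Running-hash approach: update a 64-bit FNV-1a hash as each token is appended.
// Equal sequences always produce equal hashes; unequal sequences can collide,
// so a hash match would still need the exact comparison above to confirm it.
struct token_hash {
    uint64_t value = 14695981039346656037ULL; // FNV-1a 64-bit offset basis

    void push(token_id id) {
        const uint32_t u = static_cast<uint32_t>(id);
        for (int i = 0; i < 4; ++i) {
            value ^= (u >> (8 * i)) & 0xffu;  // mix in one byte at a time
            value *= 1099511628211ULL;        // FNV 64-bit prime
        }
    }
};
```

The exact comparison stays correct even when two decoders report slightly different `sum_logprobs` for the same tokens. The running hash only pays off if sequences become long enough for the linear comparison to matter.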
As of ggerganov#1486, whisper.cpp uses a unified KV cache with KQ masking. As a result, depending on their location in the batch, identical sequences in a batch can have slightly different outputs due to floating point rounding errors during reduction. See the discussion in ggerganov#1941 for more details. The beam search code used "has identical sum of log probabilities" as a shorthand for "is an identical token sequence". However, per above, identical tokens do not necessarily result in identical probabilities. Instead, explicitly compare on sequences. This is linear in cost when they are identical, but the lengths are always small and the comparisons are cheap. This increases diversity during beam search. This improves output quality for some short samples I've been working with, at no detectable performance cost. I haven't checked against larger corpuses. Fixes ggerganov#1941
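The rounding effect described in this commit message comes from floating point addition not being associative: summing the same values in a different order can give a different result. The sketch below shows this with deliberately exaggerated magnitudes; the differences seen in the logits are far smaller. `sum_in_order` is a hypothetical helper, not a ggml kernel, and the example assumes default compiler settings (no fast-math style reordering).

```cpp
#include <vector>

// Sum the values strictly left to right using a single-precision accumulator.
float sum_in_order(const std::vector<float> & values) {
    float acc = 0.0f;
    for (float x : values) {
        acc += x;
    }
    return acc;
}

// Same three values, two different orders.
// In single precision, 1e8f + 1.0f rounds back to 1e8f, so the 1.0f is lost
// when it is added before the large values cancel.
const std::vector<float> order_a = {1e8f, 1.0f, -1e8f};
const std::vector<float> order_b = {1e8f, -1e8f, 1.0f};
```

Under those assumptions `sum_in_order(order_a)` and `sum_in_order(order_b)` differ even though the inputs are the same set of numbers. A reduction over a unified buffer has the same property: where a sequence's values sit in the buffer affects the order in which they are accumulated.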
Summary
Given a beam search that contains two decoders with identical inputs (encoded state, past tokens), the same token is getting slightly different logits across those two decoders.
This matters because (a) it hints that there's a bug somewhere and (b) the beam search de-dup logic assumes that logprobs will be identical for identical sequences.
It happens with metal, but not with cpu-only, suggesting a bug in the metal graph evaluation.
Reproduce
#define WHISPER_DEBUG
Run:
./main -m models/ggml-large-v2.bin sid5s.wav
Result:
The lines of interest are
Observe that plog and sum_logprobs are slightly different. They should be identical; the sequences leading up to them are identical ([_SOT_], [_LANG_en], [_TRANSCRIBE_], [_BEG_]).
It does not reproduce when running with cpu only:
./main -m models/ggml-large-v2.bin -ng sid5s.wav
Relevant lines are:
The absolute values also differ across metal and cpu; I assume that this is expected?