gemma2: add sliding window mask #8227
Changes from 5 commits
@@ -317,6 +317,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_Q_LORA_RANK,
     LLM_KV_ATTENTION_KV_LORA_RANK,
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
+    LLM_KV_ATTENTION_SLIDING_WINDOW,

     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_FREQ_BASE,
@@ -409,6 +410,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_Q_LORA_RANK,            "%s.attention.q_lora_rank"            },
     { LLM_KV_ATTENTION_KV_LORA_RANK,           "%s.attention.kv_lora_rank"           },
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
+    { LLM_KV_ATTENTION_SLIDING_WINDOW,         "%s.attention.sliding_window"         },

     { LLM_KV_ROPE_DIMENSION_COUNT,             "%s.rope.dimension_count"             },
     { LLM_KV_ROPE_FREQ_BASE,                   "%s.rope.freq_base"                   },
@@ -2099,6 +2101,7 @@ struct llama_hparams {
     uint32_t n_ff_shexp = 0;
     uint32_t n_expert_shared = 0;
     float    expert_weights_scale = 0.0;
+    uint32_t n_sliding = 0; // sliding window attention (SWA)

     float f_norm_eps;
     float f_norm_rms_eps;
@@ -2661,6 +2664,9 @@ struct llama_context {
     struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
     struct ggml_tensor * inp_s_seq;  // I32 [n_kv, n_batch]

+    // KQ mask per layer, used by sliding window attention (gemma 2)
+    struct ggml_tensor * inp_KQ_mask_SWA;
+
     // control vectors
     struct llama_control_vector cvec;
 };

Review thread on `inp_KQ_mask_SWA`:
- Suggested change: …
- Changed in ed5496f
@@ -4709,6 +4715,8 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_GEMMA2:
             {
+                hparams.n_sliding = 4096; // default value of gemma 2
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_sliding, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
                 ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
@@ -7786,6 +7794,7 @@ struct llm_build_context {
         lctx.inp_s_copy  = nullptr;
         lctx.inp_s_mask  = nullptr;
         lctx.inp_s_seq   = nullptr;
+        lctx.inp_KQ_mask_SWA = nullptr;
     }

     void free() {
@@ -7938,15 +7947,18 @@ struct llm_build_context {
         return lctx.inp_out_ids;
     }

-    struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
-        if (causal) {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    struct ggml_tensor * build_inp_KQ_mask(bool causal = true, bool sliding_window = false) {
+        struct ggml_tensor * KQ_mask = causal
+            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(KQ_mask, "KQ_mask", -1);
+        ggml_set_input(KQ_mask);
+        if (sliding_window) {
+            lctx.inp_KQ_mask_SWA = KQ_mask;
         } else {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+            lctx.inp_KQ_mask = KQ_mask;
         }
-        cb(lctx.inp_KQ_mask, "KQ_mask", -1);
-        ggml_set_input(lctx.inp_KQ_mask);
-        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
+        return flash_attn ? ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16) : KQ_mask;
     }

     struct ggml_tensor * build_inp_mean() {
@@ -11029,9 +11041,14 @@ struct llm_build_context {
         struct ggml_tensor * inp_pos = build_inp_pos();

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+        // gemma 2 requires different mask for layers using sliding window (SWA)
+        struct ggml_tensor * KQ_mask_full = build_inp_KQ_mask(true, false);
+        struct ggml_tensor * KQ_mask_SWA  = build_inp_KQ_mask(true, true);

         for (int il = 0; il < n_layer; ++il) {
+            // (il % 2) layers use SWA
+            struct ggml_tensor * KQ_mask = (il % 2 == 0) ? KQ_mask_SWA : KQ_mask_full;
+
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm, NULL,
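To make the alternation explicit, here is a throwaway illustration of which layers end up with the sliding-window mask under the `(il % 2 == 0)` rule used above (the layer count is a toy value, not taken from any Gemma 2 checkpoint):

```cpp
#include <cstdio>

int main() {
    const int n_layer = 6;                       // toy value for illustration
    for (int il = 0; il < n_layer; ++il) {
        const bool use_swa = (il % 2 == 0);      // even-indexed layers -> sliding-window mask
        printf("layer %d uses the %s mask\n", il, use_swa ? "SWA" : "full causal");
    }
    return 0;
}
```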
@@ -12670,7 +12687,17 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {

             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));

-            float * data = (float *) lctx.inp_KQ_mask->data;
+            float * data     = (float *) lctx.inp_KQ_mask->data;
+            float * data_swa = nullptr;
+            const llama_pos n_keep_swa = hparams.n_sliding - batch.n_tokens;
Review thread on `n_keep_swa`:
- I don't understand the meaning of `n_keep_swa`.
- Yeah, I'm not sure if I'm doing it correctly: it is to emulate the rolling. If we input …
- Seems to me that just restricting the position delta to be less than `hparams.n_sliding` should work:

  diff --git a/src/llama.cpp b/src/llama.cpp
  index 71b7ef62..fa207234 100644
  --- a/src/llama.cpp
  +++ b/src/llama.cpp
  @@ -12722,7 +12722,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                       // may need to cut off old tokens for sliding window
                       if (data_swa) {
  -                        if (pos - lctx.kv_self.cells[i].pos > n_keep_swa) {
  +                        if (pos - lctx.kv_self.cells[i].pos >= hparams.n_sliding) {
                               f = -INFINITY;
                           }
                           data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;

  This way, in SWA layers, the token with position 4096 does not "see" the token at position 0, but does "see" the token at position 1.
- OK thanks, that's clear for me now. I changed this code in ed5496f.
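To make the two conditions concrete, here is a small throwaway check of the `>= hparams.n_sliding` rule on the positions mentioned above (the helper is only an illustration, not code from the patch):

```cpp
#include <cassert>
#include <cstdint>

// Illustration only: should the key at position `pos_cell` be masked out for a
// query at position `pos`, given a sliding window of `n_sliding` tokens?
static bool masked_swa(int32_t pos, int32_t pos_cell, uint32_t n_sliding) {
    return pos - pos_cell >= (int32_t) n_sliding;   // the ">= n_sliding" rule suggested above
}

int main() {
    const uint32_t n_sliding = 4096;
    assert( masked_swa(4096, 0, n_sliding));  // position 4096 does not "see" position 0
    assert(!masked_swa(4096, 1, n_sliding));  // ... but does "see" position 1
    assert(!masked_swa(   5, 0, n_sliding));  // inside the window, nothing is cut off
    return 0;
}
```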
+
+            if (lctx.model.arch == LLM_ARCH_GEMMA2) {
+                GGML_ASSERT(lctx.inp_KQ_mask_SWA);
+                GGML_ASSERT(hparams.n_sliding > 0);
+                data     = (float *) lctx.inp_KQ_mask->data;
+                data_swa = (float *) lctx.inp_KQ_mask_SWA->data;
+                // because layer masks are alternate for gemma 2, we only need to take first 2 layers
+            }
Review thread on the GEMMA2 block above:
- This can be simplified a bit. Suggested change: …
- If I am not mistaken, mistral uses SWA on every layer. So maybe this needs to be separated to allow having only `data_swa`.
- I've just looked at the mistral reference implementation; they seem to use a different mask for each layer. Link: https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/cache.py So I think my previous version (using …). It surprises me a bit, since mistral's quality doesn't seem to degrade even though it's missing SWA (or does it only break after 4096 tokens?).
- I have been looking at this code for a while and reviewing the mistral paper, and I think this is an implementation of the rolling buffer cache rather than sliding window attention. As far as I can tell, mistral has the same sliding window of 4096 tokens on each layer. Knowing that, it is possible to reduce the size of the KV cache to the sliding window size, but that requires some additional housekeeping so that e.g. the rope still receives the absolute positions of the tokens while the data is actually stored at the position modulo the window size.
- Yes, it should be possible. The thing I cannot figure out is how to avoid calling …
- Yeah, I assume that code is a reference implementation, so not very good quality. Having a rolling buffer would be ideal for llama.cpp, but it seems like too many changes. This is mostly to answer your question earlier: will the same implementation work? Yes, it works with a different sliding window mask per layer, but it will be a waste of memory without a rolling buffer.
- How would the mask differ in each layer? My understanding is that the mask would be the same for all the layers, and it relies on the fact that the states in the KV cache depend on all the previous tokens to be able to access information beyond the sliding window.
- I looked deeper into the paper; it seems I missed something. Looking at the figure and its explanation, I'd assume that the mask for each layer is shifted by the size of the sliding window. But then what I don't understand is the phrase "position i of the layer k, h_i, attends to all hidden states from the previous layer …". Also, looking at the HF implementation code, there seems to be no such thing; they just apply the same attention mask for all layers: https://github.com/huggingface/transformers/blob/e65502951593a76844e872fee9c56b805598538a/src/transformers/models/mistral/modeling_mistral.py#L354
- Changed in ed5496f. I think for now we can keep the implementation this way; I'll need more time to figure out how mistral actually uses SWA.
- I think it doesn't directly "attend" to the tokens from the previous one; it just receives information about those tokens through the output of the previous layer. I have also been trying to understand this concept for the past 3 days. I did not pay attention to this when Mistral v1 was released, and I remember seeing that Mistral v2 removed SWA.
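A short sketch of the receptive-field argument being discussed, stated as math rather than verified code (my reading of the paper's figure; W is the window size and k the layer index):

```latex
% Per-layer window: at layer 1, h_i^{(1)} can depend on inputs x_j with i - W < j <= i.
% Stacking layers widens the indirect reach by W per layer (by induction on k):
\[
  h_i^{(k)} \text{ can depend on } x_j \quad \text{for} \quad i - kW < j \le i ,
\]
% so even with the same per-layer mask everywhere, information can still travel
% roughly k * W positions through the stack, e.g. 32 layers * 4096 = 131072 positions.
```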

             // For causal attention, use only the previous KV cells
             // of the correct sequence for each token of the batch.
@@ -12692,6 +12719,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                         }
                     }
                     data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+
+                    // may need to cut off old tokens for sliding window
+                    if (data_swa) {
+                        if (pos - lctx.kv_self.cells[i].pos > n_keep_swa) {
+                            f = -INFINITY;
+                        }
+                        data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
+                    }
                 }
             }
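For reference, a self-contained toy version of the mask filling done in this loop, with the KV-cell and sequence bookkeeping stripped down to plain arrays and using the `>= n_sliding` cutoff from the review discussion above (all names and sizes here are made up for the illustration):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Simplified illustration of the KQ mask filling: one sequence, n_kv cached
// cells holding positions 0..n_kv-1, and n_tokens new tokens with positions
// pos0..pos0+n_tokens-1. Causal masking everywhere, plus the sliding-window
// cutoff for the SWA mask.
int main() {
    const int n_kv      = 8;
    const int n_tokens  = 2;
    const int pos0      = 6;      // positions 6 and 7 are being decoded
    const int n_sliding = 4;      // toy window size
    std::vector<float> data(n_kv * n_tokens);      // full causal mask
    std::vector<float> data_swa(n_kv * n_tokens);  // sliding-window mask

    for (int j = 0; j < n_tokens; ++j) {
        const int pos = pos0 + j;
        for (int i = 0; i < n_kv; ++i) {
            const int cell_pos = i;                          // cell i holds position i here
            float f = (cell_pos > pos) ? -INFINITY : 0.0f;   // causal condition
            data[j*n_kv + i] = f;
            if (pos - cell_pos >= n_sliding) {               // cut off old tokens for the window
                f = -INFINITY;
            }
            data_swa[j*n_kv + i] = f;
        }
    }

    // Print the SWA mask: 'x' = masked out, '.' = attended.
    for (int j = 0; j < n_tokens; ++j) {
        for (int i = 0; i < n_kv; ++i) {
            printf("%s", std::isinf(data_swa[j*n_kv + i]) ? "x " : ". ");
        }
        printf("\n");
    }
    return 0;
}
```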