Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gemma2: add sliding window mask #8227

Merged
merged 10 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2369,6 +2369,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_final_logit_softcapping(
self.hparams["final_logit_softcapping"]
)
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
Expand Down
1 change: 1 addition & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class Attention:
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
SLIDING_WINDOW = "{arch}.attention.sliding_window"

class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
Expand Down
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,9 @@ def add_kv_lora_rank(self, length: int) -> None:
def add_relative_attn_buckets_count(self, value: int) -> None:
self.add_uint32(Keys.Attention.REL_BUCKETS_COUNT.format(arch=self.arch), value)

def add_sliding_window(self, value: int) -> None:
self.add_uint32(Keys.Attention.SLIDING_WINDOW.format(arch=self.arch), value)

def add_pooling_type(self, value: PoolingType) -> None:
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)

Expand Down
53 changes: 44 additions & 9 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ enum llm_kv {
LLM_KV_ATTENTION_Q_LORA_RANK,
LLM_KV_ATTENTION_KV_LORA_RANK,
LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
LLM_KV_ATTENTION_SLIDING_WINDOW,

LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_FREQ_BASE,
Expand Down Expand Up @@ -409,6 +410,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
{ LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
{ LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },

{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
Expand Down Expand Up @@ -2099,6 +2101,7 @@ struct llama_hparams {
uint32_t n_ff_shexp = 0;
uint32_t n_expert_shared = 0;
float expert_weights_scale = 0.0;
uint32_t n_sliding = 0; // sliding window attention (SWA)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
uint32_t n_sliding = 0; // sliding window attention (SWA)
uint32_t n_swa = 0; // sliding window attention (SWA)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed in ed5496f


float f_norm_eps;
float f_norm_rms_eps;
Expand Down Expand Up @@ -2661,6 +2664,9 @@ struct llama_context {
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]

// KQ mask per layer, used by sliding window attention (gemma 2)
struct ggml_tensor * inp_KQ_mask_SWA;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
struct ggml_tensor * inp_KQ_mask_SWA;
struct ggml_tensor * inp_KQ_mask_swa;

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed in ed5496f


// control vectors
struct llama_control_vector cvec;
};
Expand Down Expand Up @@ -4709,6 +4715,8 @@ static void llm_load_hparams(
} break;
case LLM_ARCH_GEMMA2:
{
hparams.n_sliding = 4096; // default value of gemma 2
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_sliding, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
Expand Down Expand Up @@ -7786,6 +7794,7 @@ struct llm_build_context {
lctx.inp_s_copy = nullptr;
lctx.inp_s_mask = nullptr;
lctx.inp_s_seq = nullptr;
lctx.inp_KQ_mask_SWA = nullptr;
}

void free() {
Expand Down Expand Up @@ -7938,15 +7947,18 @@ struct llm_build_context {
return lctx.inp_out_ids;
}

struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
if (causal) {
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
struct ggml_tensor * build_inp_KQ_mask(bool causal = true, bool sliding_window = false) {
struct ggml_tensor * KQ_mask = causal
? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
: ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
cb(KQ_mask, "KQ_mask", -1);
ggml_set_input(KQ_mask);
if (sliding_window) {
lctx.inp_KQ_mask_SWA = KQ_mask;
} else {
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
lctx.inp_KQ_mask = KQ_mask;
}
cb(lctx.inp_KQ_mask, "KQ_mask", -1);
ggml_set_input(lctx.inp_KQ_mask);
return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
return flash_attn ? ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16) : KQ_mask;
}

struct ggml_tensor * build_inp_mean() {
Expand Down Expand Up @@ -11029,9 +11041,14 @@ struct llm_build_context {
struct ggml_tensor * inp_pos = build_inp_pos();

// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
// gemma 2 requires different mask for layers using sliding window (SWA)
struct ggml_tensor * KQ_mask_full = build_inp_KQ_mask(true, false);
struct ggml_tensor * KQ_mask_SWA = build_inp_KQ_mask(true, true);

for (int il = 0; il < n_layer; ++il) {
// layers with even index (il % 2 == 0) use SWA
struct ggml_tensor * KQ_mask = (il % 2 == 0) ? KQ_mask_SWA : KQ_mask_full;

// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
Expand Down Expand Up @@ -12670,7 +12687,17 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {

GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));

float * data = (float *) lctx.inp_KQ_mask->data;
float * data = (float *) lctx.inp_KQ_mask->data;
float * data_swa = nullptr;
const llama_pos n_keep_swa = hparams.n_sliding - batch.n_tokens;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand the meaning of n_keep_swa. Seems this won't work with batches of multiple sequences

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'm not sure if I'm doing it correctly: it is to emulate the rolling. If we input n_tokens, then we only keep (n_sliding - n_tokens) tokens in cache, so the total number of tokens for attention is n_tokens + (n_sliding - n_tokens) = n_sliding

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems to me just restricting the position delta to be less than n_swa is enough:

diff --git a/src/llama.cpp b/src/llama.cpp
index 71b7ef62..fa207234 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -12722,7 +12722,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
 
                         // may need to cut off old tokens for sliding window
                         if (data_swa) {
-                            if (pos - lctx.kv_self.cells[i].pos > n_keep_swa) {
+                            if (pos - lctx.kv_self.cells[i].pos >= hparams.n_sliding) {
                                 f = -INFINITY;
                             }
                             data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;

This way, in SWA layers, the token with position 4096 does not "see" the token with position 0, but does "see" the token at position 1.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK thanks, that's clear for me now. I changed this code in ed5496f


if (lctx.model.arch == LLM_ARCH_GEMMA2) {
GGML_ASSERT(lctx.inp_KQ_mask_SWA);
GGML_ASSERT(hparams.n_sliding > 0);
data = (float *) lctx.inp_KQ_mask->data;
data_swa = (float *) lctx.inp_KQ_mask_SWA->data;
// because layer masks alternate for gemma 2, we only need to take the first 2 layers
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be simplified a bit.

Suggested change
if (lctx.model.arch == LLM_ARCH_GEMMA2) {
GGML_ASSERT(lctx.inp_KQ_mask_SWA);
GGML_ASSERT(hparams.n_sliding > 0);
data = (float *) lctx.inp_KQ_mask->data;
data_swa = (float *) lctx.inp_KQ_mask_SWA->data;
// because layer masks are alternate for gemma 2, we only need to take first 2 layers
}
if (lctx.inp_KQ_mask_SWA) {
data_swa = (float *) lctx.inp_KQ_mask_SWA->data;
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I am not mistaken, mistral uses SWA every layer. So maybe this needs to be separated to allow having only inp_KQ_mask_SWA? Will the same implementation work?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've just looked at mistral reference implementation, they seem to use different mask for each layer. Link: https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/cache.py

So I think my previous version (using std::vector) can handle that. Do you think I should revert the change?

It surprises me a bit, since mistral's quality doesn't seem to degrade even though it's missing SWA (or does it only break after 4096 tokens?)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have been looking at this code for a while and reviewing the mistral paper, and I think this is an implementation of the rolling buffer cache rather than sliding window attention. As far as I can tell, mistral has the same sliding window of 4096 tokens on each layer. Knowing that, it is possible to reduce the size of the KV cache to the sliding window size, but that requires some additional housekeeping so that eg. the rope still receives the absolute positions of the tokens, but the data is actually stored in the position pos % sliding_window. But maybe I am misunderstanding something, can you point me to the specific code?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it should be possible. The thing I cannot figure out is how to avoid calling llama_kv_cache_find_slot() per-layer - seems it would be a big waste to do it like this, although it would generalize to support arbitrary KV cache layer sizes

Copy link
Collaborator Author

@ngxson ngxson Jul 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I assume the code is a reference implementation, so not very good quality. Having a rolling buffer would be ideal for llama.cpp, but it seems like too many changes. This is mostly to answer your question earlier: Will the same implementation work? Yes, it works with a different sliding window mask per layer, but it will be a waste of memory without a rolling buffer.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would the mask differ in each layer? My understanding is that the mask would be the same for all the layers, and it relies on the fact that the states in the KV cache depend on all the previous tokens to be able to access information beyond the sliding window.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked deeper into the paper, seems like I missed something.

Looking at this figure:

image

And the explanation:

image

I'd assume that the mask for each layer is shifted by the size of window - 1, for example:

  • layer 0: 0, 0, 0, 1, 1
  • layer 1: 0, 0, 1, 1, 0
  • layer 2: 0, 1, 1, 0, 0
  • ...

But then what I don't understand is the phrase "position i of the layer k, hi, attends to all hidden states from
the previous layer with positions between i − W and i". On the surface, it seems to explain how layer 1 knows about the tokens that fall outside of its window (which are in layer 0), but then what's not clear to me is how one layer can attend to the previous one.

Also looking at the HF implementation code, seems like there is no such thing. They just add same attention mask for all layers: https://github.com/huggingface/transformers/blob/e65502951593a76844e872fee9c56b805598538a/src/transformers/models/mistral/modeling_mistral.py#L354

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be simplified a bit.

Changed in ed5496f

I think for now we can keep the implementation this way, I'll need more time to figure out how mistral actually use SWA.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But then what I don't understand is the phrase "position i of the layer k, hi, attends to all hidden states from the previous layer with positions between i − W and i". On the surface, it seems to explain how layer 1 knows about the tokens fall outside of its window (which is in layer 0), but then what's not clear to me is how one layer can attend to the previous one.

I think it doesn't directly "attend" to the tokens from the previous one. It just receives information about those tokens through the output of the previous layer.

I am also trying to understand this concept from the past 3 days. I did not pay attention to this when Mistral v1 was released and I remember seeing that Mistral v2 removed SWA.


// For causal attention, use only the previous KV cells
// of the correct sequence for each token of the batch.
Expand All @@ -12692,6 +12719,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;

// may need to cut off old tokens for sliding window
if (data_swa) {
if (pos - lctx.kv_self.cells[i].pos > n_keep_swa) {
f = -INFINITY;
}
data_swa[h*(n_kv*n_tokens) + j*n_kv + i] = f;
}
}
}

Expand Down
Loading