---
layout: post
title: "Visual Autoregressive Modeling: Scalable Image Generation via Next-Scale Prediction"
categories: []
year: 2024
type: paper
---

DeepSeek V3 is an astonishing feat of engineering. Model performance aside (the topic seems contentious; I've heard mixed opinions so far, and apparently it gets stuck in a reasoning spiral sometimes), being able to train a model of this capacity on a 2048 H800 cluster in under two months? Crazy. They only spent about $5M training this thing. From what I hear the DeepSeek team is barely more than 100 people, and it's all in-house Chinese talent. I mean, if you were worried about China's development before... yeah, I don't know, they're fucking good.

Despite what it may seem like, however, this post is not *only* here to rave about how good the DeepSeek team is, but rather to provide a formal introduction to one of DeepSeek's own inventions: Multi-head Latent Attention (MLA). MLA was first introduced in DeepSeek V2, back in the spring of this year I believe. I spent some time digging into it back then, but unfortunately my mind is slipping on the details, so I'm going to take another pass at it and you can come along with me. I'll be looking at it through the perspective of the historical evolution of attention mechanisms: MHA -> MQA -> GQA -> MLA. The content below is heavily inspired by (and a considerable amount is directly translated from) a post by Jianlin Su, the author of RoPE, who runs an incredible [blog](https://kexue.fm/) in Chinese.

### MHA

Multi-head attention (MHA) is the traditional attention mechanism defined in *Attention Is All You Need*. Suppose the input sequence consists of row vectors $\mathbf{x}_1, \mathbf{x}_2, \dots, \mathbf{x}_l$ where $\mathbf{x}_i \in \mathbb{R}^{d}$; then MHA is formally represented as:

$$
\begin{aligned}
\mathbf{o}_t &= [\mathbf{o}_t^{(1)}, \mathbf{o}_t^{(2)}, \dots, \mathbf{o}_t^{(h)}] \\
\mathbf{o}_t^{(s)} &= \text{Attention}(\mathbf{q}_t^{(s)}, \mathbf{k}_{\leq t}^{(s)}, \mathbf{v}_{\leq t}^{(s)})
\triangleq \frac{\sum_{i \leq t} \exp(\mathbf{q}_t^{(s)} \mathbf{k}_i^{(s)\top}) \mathbf{v}_i^{(s)}}{\sum_{i \leq t} \exp(\mathbf{q}_t^{(s)} \mathbf{k}_i^{(s)\top})} \\
\mathbf{q}_i^{(s)} &= \mathbf{x}_i \mathbf{W}_q^{(s)} \in \mathbb{R}^{d_k}, \quad \mathbf{W}_q^{(s)} \in \mathbb{R}^{d \times d_k} \\
\mathbf{k}_i^{(s)} &= \mathbf{x}_i \mathbf{W}_k^{(s)} \in \mathbb{R}^{d_k}, \quad \mathbf{W}_k^{(s)} \in \mathbb{R}^{d \times d_k} \\
\mathbf{v}_i^{(s)} &= \mathbf{x}_i \mathbf{W}_v^{(s)} \in \mathbb{R}^{d_v}, \quad \mathbf{W}_v^{(s)} \in \mathbb{R}^{d \times d_v}
\end{aligned}
$$

An example configuration of the above parameters (Llama 3.1 70B) is $d = 8192$, $d_k = 128$, $h = 64$. Note that $d_k = d / h$ is common practice.

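To make the notation concrete, here is a minimal PyTorch sketch of the per-token MHA computation above. It mirrors the formulas (including their omission of the usual $1/\sqrt{d_k}$ scaling); the dimensions and tensor names are illustrative choices of mine, not a real configuration.

```python
import torch

# Illustrative, scaled-down dimensions (not the Llama config above)
d, d_k, d_v, h = 256, 32, 32, 8
t = 10  # tokens seen so far, including the current one

torch.manual_seed(0)
W_q = torch.randn(h, d, d_k) / d**0.5   # W_q^(s), one per head s
W_k = torch.randn(h, d, d_k) / d**0.5   # W_k^(s)
W_v = torch.randn(h, d, d_v) / d**0.5   # W_v^(s)

x = torch.randn(t, d)                   # x_1 ... x_t as row vectors

q_t = torch.einsum('d,sde->se', x[-1], W_q)   # q_t^(s), shape (h, d_k)
k = torch.einsum('id,sde->sie', x, W_k)       # k_{<=t}^(s), shape (h, t, d_k)
v = torch.einsum('id,sde->sie', x, W_v)       # v_{<=t}^(s), shape (h, t, d_v)

scores = torch.einsum('se,sie->si', q_t, k)   # q_t^(s) k_i^(s)^T for all i <= t
weights = torch.softmax(scores, dim=-1)       # exp(.) / sum_i exp(.)
o_heads = torch.einsum('si,sie->se', weights, v)

o_t = o_heads.reshape(h * d_v)                # [o_t^(1), ..., o_t^(h)]
print(o_t.shape)                              # torch.Size([256])
```
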
During **inference**, a causal autoregressive language model generates tokens recursively, meaning the generation of token $t + 1$ does not affect the previously computed matrices $\mathbf{k}_{\leq t}^{(s)}, \mathbf{v}_{\leq t}^{(s)}$. These matrices can therefore be stored in a KV cache to avoid redundant computation, spending memory to save compute. However, the KV cache grows with both the model size and the input length. At sufficiently long context lengths, the KV cache can consume the majority of GPU memory, often surpassing the memory required for model parameters and activations (although flash attention and other low-level optimizations have alleviated the issue). This scaling behaviour makes the KV cache a bottleneck for efficient inference, especially for models serving long inputs.

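To get a feel for the scale of the problem, here is a quick back-of-the-envelope estimate. The numbers are my own illustrative assumptions (fp16 cache, an 80-layer 70B-class model, 128k context), not something from the post:

```python
# Back-of-the-envelope KV cache size for a 70B-class decoder (my own numbers, fp16).
h_kv = 64          # KV heads per layer if we ran full MHA (Llama 3.1 70B actually uses 8 KV heads via GQA)
d_k = 128          # per-head dimension
layers = 80        # decoder layers
bytes_per = 2      # fp16 / bf16
seq_len = 128_000  # context length
batch = 1

per_token = 2 * h_kv * d_k * bytes_per * layers   # keys + values, all layers
total_gb = per_token * seq_len * batch / 1e9
print(f"{per_token / 1e6:.1f} MB per token, {total_gb:.0f} GB at {seq_len} tokens")
# ~2.6 MB/token and ~335 GB at 128k context with full MHA; with 8 KV heads (GQA) it is 8x smaller, ~42 GB,
# still far from free compared to the ~140 GB of fp16 weights.
```
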
![](/images/kvcache.png)

A solution would be to deploy such models across multiple cards, or when necessary across multiple machines. However, a guiding principle when deploying models across a GPU cluster is that intra-card bandwidth > inter-card bandwidth > inter-machine bandwidth: the more devices a deployment spans, the higher the communication overhead and cost become. We therefore want to minimize the KV cache so that we can serve long-context models on as few GPUs as possible, with the ultimate goal of lowering inference costs.

This is the guiding motivation behind the subsequent developments of the attention mechanism.

### MQA

Multi-query attention (MQA) is the extreme alternative to MHA. Published in the 2019 paper [*Fast Transformer Decoding: One Write-Head is All You Need*](https://arxiv.org/abs/1911.02150), it represents an early, aggressive response to the looming problem of the KV cache. If one understands MHA, understanding MQA is simple: let all attention heads share the same keys and values. Formally, this means canceling the superscripts of all $\mathbf{k}, \mathbf{v}$ in MHA:

$$
\begin{aligned}
\mathbf{o}_t &= [\mathbf{o}_t^{(1)}, \mathbf{o}_t^{(2)}, \dots, \mathbf{o}_t^{(h)}] \\
\mathbf{o}_t^{(s)} &= \text{Attention}(\mathbf{q}_t^{(s)}, \mathbf{k}_{\leq t}^{\cancel{(s)}}, \mathbf{v}_{\leq t}^{\cancel{(s)}})
\triangleq \frac{\sum_{i \leq t} \exp(\mathbf{q}_t^{(s)} \mathbf{k}_i^{\cancel{(s)}\top}) \mathbf{v}_i^{\cancel{(s)}}}{\sum_{i \leq t} \exp(\mathbf{q}_t^{(s)} \mathbf{k}_i^{\cancel{(s)}\top})} \\
\mathbf{q}_t^{(s)} &= \mathbf{x}_t \mathbf{W}_q^{(s)} \in \mathbb{R}^{d_k}, \quad \mathbf{W}_q^{(s)} \in \mathbb{R}^{d \times d_k} \\
\mathbf{k}_i^{\cancel{(s)}} &= \mathbf{x}_i \mathbf{W}_k^{\cancel{(s)}} \in \mathbb{R}^{d_k}, \quad \mathbf{W}_k^{\cancel{(s)}} \in \mathbb{R}^{d \times d_k} \\
\mathbf{v}_i^{\cancel{(s)}} &= \mathbf{x}_i \mathbf{W}_v^{\cancel{(s)}} \in \mathbb{R}^{d_v}, \quad \mathbf{W}_v^{\cancel{(s)}} \in \mathbb{R}^{d \times d_v}
\end{aligned}
$$

In practice, the single $\mathbf{k}, \mathbf{v}$ head is broadcast in place across the $\mathbf{q}$ heads during computation. This reduces the KV cache to $1 / h$ of its original size, which is a significant reduction. It does, however, cost some model quality, though MQA's proponents claim this can be offset by additional training. The "saved" parameters can also be shifted to the FFN to make up for some of the lost performance.

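Continuing the sketch from the MHA section, the only change for MQA is that a single shared key/value head is broadcast across all $h$ query heads. Again, dimensions are illustrative and this is a sketch, not production code:

```python
import torch

d, d_k, d_v, h = 256, 32, 32, 8
t = 10

torch.manual_seed(0)
W_q = torch.randn(h, d, d_k) / d**0.5   # still one query projection per head
W_k = torch.randn(d, d_k) / d**0.5      # single shared W_k (superscript "cancelled")
W_v = torch.randn(d, d_v) / d**0.5      # single shared W_v

x = torch.randn(t, d)
q_t = torch.einsum('d,sde->se', x[-1], W_q)   # (h, d_k)
k = x @ W_k                                   # (t, d_k) -- this is the whole key cache
v = x @ W_v                                   # (t, d_v) -- and the whole value cache

scores = torch.einsum('se,ie->si', q_t, k)    # the shared k is broadcast across all h query heads
weights = torch.softmax(scores, dim=-1)
o_t = torch.einsum('si,ie->se', weights, v).reshape(h * d_v)

# Per token we now cache d_k + d_v numbers instead of h * (d_k + d_v): a 1/h reduction.
print(o_t.shape, k.shape, v.shape)
```
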
### GQA

Grouped-query attention (GQA) is the generalized version of MHA and MQA, published in the 2023 paper [*GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints*](https://arxiv.org/abs/2305.13245). GQA divides the KV heads into $g$ groups (where $g$ evenly divides $h$), and each group is shared by $h / g$ query heads. Formally, this is expressed as:

$$
\begin{aligned}
\mathbf{o}_t &= [\mathbf{o}_t^{(1)}, \mathbf{o}_t^{(2)}, \dots, \mathbf{o}_t^{(h)}] \\
\mathbf{o}_t^{(s)} &= \text{Attention}\left(\mathbf{q}_t^{(s)}, \mathbf{k}_{\leq t}^{\left(\lceil sg/h \rceil\right)}, \mathbf{v}_{\leq t}^{\left(\lceil sg/h \rceil\right)}\right) \\
&\triangleq
\frac{\sum_{i \leq t} \exp\left(\mathbf{q}_t^{(s)} \mathbf{k}_i^{\left(\lceil sg/h \rceil\right)\top}\right) \mathbf{v}_i^{\left(\lceil sg/h \rceil\right)}}{\sum_{i \leq t} \exp\left(\mathbf{q}_t^{(s)} \mathbf{k}_i^{\left(\lceil sg/h \rceil\right)\top}\right)} \\
\mathbf{q}_t^{(s)} &= \mathbf{x}_t \mathbf{W}_q^{(s)} \in \mathbb{R}^{d_k}, \quad \mathbf{W}_q^{(s)} \in \mathbb{R}^{d \times d_k} \\
\mathbf{k}_i^{\left(\lceil sg/h \rceil\right)} &= \mathbf{x}_i \mathbf{W}_k^{\left(\lceil sg/h \rceil\right)} \in \mathbb{R}^{d_k}, \quad \mathbf{W}_k^{\left(\lceil sg/h \rceil\right)} \in \mathbb{R}^{d \times d_k} \\
\mathbf{v}_i^{\left(\lceil sg/h \rceil\right)} &= \mathbf{x}_i \mathbf{W}_v^{\left(\lceil sg/h \rceil\right)} \in \mathbb{R}^{d_v}, \quad \mathbf{W}_v^{\left(\lceil sg/h \rceil\right)} \in \mathbb{R}^{d \times d_v}
\end{aligned}
$$

GQA generalizes MHA and MQA by varying the number of KV groups $g$. When $g = h$ it recovers MHA; when $g = 1$ it corresponds to MQA; and for $1 < g < h$, it compresses the KV cache to $g / h$ of its original size. This flexibility makes GQA a versatile and efficient middle ground, as it allows fairly precise control over the trade-off between KV cache compression and model quality.

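Here is the same sketch adapted to GQA. The $\lceil sg/h \rceil$ head-to-group mapping from the formula amounts to a `repeat_interleave` of the $g$ cached KV groups across the $h$ query heads (illustrative dimensions, my own naming):

```python
import torch

d, d_k, d_v, h, g = 256, 32, 32, 8, 2   # 8 query heads sharing g = 2 KV groups
t = 10

torch.manual_seed(0)
W_q = torch.randn(h, d, d_k) / d**0.5
W_k = torch.randn(g, d, d_k) / d**0.5   # only g key projections
W_v = torch.randn(g, d, d_v) / d**0.5   # only g value projections

x = torch.randn(t, d)
q_t = torch.einsum('d,sde->se', x[-1], W_q)      # (h, d_k)
k_groups = torch.einsum('id,gde->gie', x, W_k)   # (g, t, d_k) -- this is what gets cached
v_groups = torch.einsum('id,gde->gie', x, W_v)   # (g, t, d_v)

# Head s attends to group ceil(s*g/h); with consecutive heads sharing a group this is a repeat_interleave.
k = k_groups.repeat_interleave(h // g, dim=0)    # (h, t, d_k)
v = v_groups.repeat_interleave(h // g, dim=0)    # (h, t, d_v)

scores = torch.einsum('se,sie->si', q_t, k)
weights = torch.softmax(scores, dim=-1)
o_t = torch.einsum('si,sie->se', weights, v).reshape(h * d_v)

# g = h recovers MHA, g = 1 recovers MQA; the cache is g/h of the MHA size.
print(o_t.shape, k_groups.shape)
```
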
An important advantage of GQA is its inherent support for parallelism in the attention computation. In large models, where a single GPU is insufficient to store the full model, attention is parallelized across heads, which are processed independently before concatenation (see the formulas above). By choosing $g$ to align with the number of GPUs used for tensor parallelism, each device can keep its own KV groups locally, minimizing inter-device communication overhead and improving scalability and efficiency.

### MLA

Now that we've laid the groundwork with MHA, MQA, and GQA, we're ready to tackle Multi-head Latent Attention. At first glance, MLA introduces a low-rank projection of the KV cache, and a reader may well ask: "why did it take so long for someone to propose a low-rank decomposition of the KV cache, considering how long LoRA has been around?"

However, consider what happens in GQA when we stack all the $\mathbf{k}, \mathbf{v}$ heads together:

$$
\begin{aligned}
\underbrace{\left[\mathbf{k}_i^{(1)}, \dots, \mathbf{k}_i^{(g)}, \mathbf{v}_i^{(1)}, \dots, \mathbf{v}_i^{(g)}\right]}_{\mathbf{c}_i \in \mathbb{R}^{g(d_k + d_v)}}
&= \mathbf{x}_i
\underbrace{\left[\mathbf{W}_k^{(1)}, \dots, \mathbf{W}_k^{(g)}, \mathbf{W}_v^{(1)}, \dots, \mathbf{W}_v^{(g)}\right]}_{\mathbf{W}_c \in \mathbb{R}^{d \times g(d_k + d_v)}}
\end{aligned}
$$

If we let $\mathbf{c}_i$ denote the concatenated $\mathbf{k}, \mathbf{v}$ heads and $\mathbf{W}_c$ the corresponding stack of projection matrices, we see that GQA is already performing a low-rank projection: generally $d_c = g(d_k + d_v) < d$ (for Llama 3.1 70B, for instance, $g = 8$ and $d_k = d_v = 128$, so $d_c = 2048$ while $d = 8192$), so the transformation from $\mathbf{x}_i$ to $\mathbf{c}_i$ is low-rank. As such, the contribution of MLA is not the low-rank projection itself, but rather what happens after the projection.

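This is easy to check numerically: concatenating GQA's per-group keys and values is exactly one matrix multiplication by the stacked $\mathbf{W}_c$. A small sanity check, using the Llama-like dimensions mentioned above:

```python
import torch

d, d_k, d_v, g = 8192, 128, 128, 8
torch.manual_seed(0)

W_k = torch.randn(g, d, d_k) * 0.02   # W_k^(1..g)
W_v = torch.randn(g, d, d_v) * 0.02   # W_v^(1..g)
x_i = torch.randn(d)

# Left-hand side: per-group keys and values, concatenated into c_i
c_i = torch.cat([x_i @ W_k[s] for s in range(g)] + [x_i @ W_v[s] for s in range(g)])

# Right-hand side: a single stacked projection W_c of shape (d, g*(d_k + d_v))
W_c = torch.cat([W_k.permute(1, 0, 2).reshape(d, g * d_k),
                 W_v.permute(1, 0, 2).reshape(d, g * d_v)], dim=1)

print(c_i.shape)                                  # torch.Size([2048]): d_c = g*(d_k + d_v) = 2048 < d = 8192
print(torch.allclose(c_i, x_i @ W_c, atol=1e-4))  # True: GQA's K/V are one low-rank projection of x_i
```
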
#### Part 1

GQA down-projects $\mathbf{x}_i$ into the $g(d_k + d_v)$-dimensional $\mathbf{c}_i$, splits it into two halves for $K$ and $V$, further divides each half into $g$ parts, and replicates each part $h / g$ times to "make up" the $K$ and $V$ required for the $h$ heads. While effective, this approach imposes structural rigidity by enforcing a fixed splitting-and-replication scheme, potentially limiting the representational capacity of $K$ and $V$. MLA instead seeks a more expressive representation by replacing GQA's splitting and replication with learned linear transformations. The first transformation projects $\mathbf{x}_i$ into a shared latent space, capturing the essential features in a compressed form:

$$
\mathbf{c}_i = \mathbf{x}_i \mathbf{W}_c \in \mathbb{R}^{d_c}, \quad \mathbf{W}_c \in \mathbb{R}^{d \times d_c}.
$$

Once $\mathbf{c}_i$ is derived, it serves as the basis for generating head-specific keys and values. For each attention head $s$, a linear transformation maps $\mathbf{c}_i$ back into the per-head key space $\mathbb{R}^{d_k}$ and value space $\mathbb{R}^{d_v}$:

$$
\mathbf{k}_i^{(s)} = \mathbf{c}_i \mathbf{W}_k^{(s)} \in \mathbb{R}^{d_k}, \quad \mathbf{W}_k^{(s)} \in \mathbb{R}^{d_c \times d_k}
$$

$$
\mathbf{v}_i^{(s)} = \mathbf{c}_i \mathbf{W}_v^{(s)} \in \mathbb{R}^{d_v}, \quad \mathbf{W}_v^{(s)} \in \mathbb{R}^{d_c \times d_v}.
$$

Theoretically, this increases model capacity, but the whole point was to reduce the KV cache, so what happens to ours? In GQA we would cache the down-projected $\mathbf{k}_i, \mathbf{v}_i$; MLA, however, reconstructs all $h$ KV heads, which would cause the KV cache to revert to the size of MHA's. Interestingly, the authors leave this be during training, and circumvent the issue during inference by caching only $\mathbf{c}_i$ and fusing the projection matrices into neighboring operations: since $\mathbf{q}_t^{(s)} \mathbf{k}_i^{(s)\top} = (\mathbf{x}_t \mathbf{W}_q^{(s)}) (\mathbf{c}_i \mathbf{W}_k^{(s)})^\top = (\mathbf{x}_t \mathbf{W}_q^{(s)} \mathbf{W}_k^{(s)\top}) \mathbf{c}_i^\top$, each $\mathbf{W}_k^{(s)}$ can be absorbed into the query projection, and analogously each $\mathbf{W}_v^{(s)}$ can be absorbed into the output projection. Notably, since $\mathbf{c}_i$ is independent of $s$, i.e. shared across all heads, MLA effectively becomes MQA during inference.

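To see why caching only $\mathbf{c}_i$ works, here is a minimal sketch of the two views of the key side, ignoring RoPE entirely; the dimensions and the "absorbed" naming are my own illustration, not DeepSeek's implementation:

```python
import torch

d, d_c, d_k, d_v, h = 256, 64, 32, 32, 8
t = 10

torch.manual_seed(0)
W_c = torch.randn(d, d_c) * 0.05       # shared down-projection to the latent c_i
W_q = torch.randn(h, d, d_k) * 0.05    # per-head query projections
W_k = torch.randn(h, d_c, d_k) * 0.05  # per-head key up-projections from c_i
W_v = torch.randn(h, d_c, d_v) * 0.05  # per-head value up-projections from c_i (not used below)

x = torch.randn(t, d)
c = x @ W_c                            # (t, d_c) -- the only thing cached at inference time
q_t = torch.einsum('d,sde->se', x[-1], W_q)

# Training view: reconstruct all h key heads from c, then score as in MHA (an MHA-sized cache if stored).
k = torch.einsum('ic,sce->sie', c, W_k)            # (h, t, d_k)
scores_train = torch.einsum('se,sie->si', q_t, k)

# Inference view: absorb W_k^(s) into the query and attend over the shared c directly (MQA-like).
q_absorbed = torch.einsum('se,sce->sc', q_t, W_k)  # q_t^(s) W_k^(s)^T, shape (h, d_c)
scores_infer = torch.einsum('sc,ic->si', q_absorbed, c)

print(torch.allclose(scores_train, scores_infer, atol=1e-4))  # True: same attention scores, smaller cache
# The value side works the same way: W_v^(s) can be folded into the output projection,
# so only the d_c-dimensional c_i needs to be cached per token instead of h*(d_k + d_v) values.
```

Caching $\mathbf{c}_i$ alone means the per-token cache is $d_c$ numbers regardless of $h$, which is exactly the MQA-like behaviour described above.
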
#### Part 2