Support for KV caching and batched inference
mseeger committed Feb 6, 2025
1 parent f6031e3 commit ff817a9
Showing 30 changed files with 2,381 additions and 788 deletions.
65 changes: 50 additions & 15 deletions litgpt/adapter.py
@@ -9,7 +9,7 @@
"""

from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, List

import torch
import torch.nn as nn
@@ -19,6 +19,7 @@
from litgpt.model import GPT as BaseModel
from litgpt.model import Block as BaseBlock
from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention
from litgpt.kvcache.base import KVCache, KeysAndValues, DefaultKeysAndValues


@dataclass
@@ -29,20 +30,33 @@ class Config(BaseConfig):

class GPT(BaseModel):
# Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here.
def __init__(self, config: Config) -> None:
def __init__(
self,
config: Config,
kv_cache: Optional[List[KVCache]] = None
) -> None:
nn.Module.__init__(self)
assert config.padded_vocab_size is not None
self.config = config

if kv_cache is not None:
if len(kv_cache) != config.n_layer:
raise ValueError(f"kv_cache length {len(kv_cache)} != {config.n_layer} = config.n_layer")
for kvc in kv_cache:
self._check_kv_cache(config, kvc)
self._default_kv_cache = False
else:
kv_cache = [None] * config.n_layer
self._default_kv_cache = True
self.lm_head = nn.Linear(
config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
h=nn.ModuleList(
Block(config, block_idx)
for block_idx in range(config.n_layer)
Block(config, block_idx, kv_cache=kvc)
for block_idx, kvc in enumerate(kv_cache)
),
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
)
@@ -62,17 +76,27 @@ def _init_weights(self, module: nn.Module) -> None:


class Block(BaseBlock):
def __init__(self, config: Config, block_idx: int) -> None:
super().__init__(config, block_idx)
self.attn = CausalSelfAttention(config, block_idx)
def __init__(
self,
config: Config,
block_idx: int,
kv_cache: Optional[KVCache] = None,
) -> None:
super().__init__(config, block_idx, kv_cache)
self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache)


class CausalSelfAttention(BaseCausalSelfAttention):
"""A modification of `litgpt.model.CausalSelfAttention` that adds the attention
over the adaption prompt."""

def __init__(self, config: Config, block_idx: int) -> None:
super().__init__(config, block_idx)
def __init__(
self,
config: Config,
block_idx: int,
kv_cache: Optional[KVCache] = None,
) -> None:
super().__init__(config, block_idx, kv_cache)
if block_idx >= config.adapter_start_layer:
# adapter embedding layer
self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
@@ -82,11 +106,16 @@ def __init__(self, config: Config, block_idx: int) -> None:
self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

def scaled_dot_product_attention(
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
y = super().scaled_dot_product_attention(q, k, v, mask)
self,
q: torch.Tensor,
k_and_v: KeysAndValues,
mask: Optional[torch.Tensor] = None,
is_causal: bool = True,
return_scores: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
y, scores = super().scaled_dot_product_attention(q, k_and_v, mask, is_causal, return_scores)
if self.block_idx < self.config.adapter_start_layer:
return y
return y, scores

aT = self.config.adapter_prompt_length
if self.adapter_kv_cache is not None:
@@ -110,8 +139,14 @@ def scaled_dot_product_attention(

T = q.size(2)
amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)
ay = super().scaled_dot_product_attention(q, ak, av, amask)
return y + self.gating_factor * ay
a_k_and_v = DefaultKeysAndValues(keys=ak, values=av)
ay, _ = super().scaled_dot_product_attention(
q=q,
k_and_v=a_k_and_v,
mask=amask,
is_causal=False,
)
return y + self.gating_factor * ay, scores

def reset_parameters(self) -> None:
if hasattr(self, "gating_factor"):
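Taken together, the adapter `GPT` constructor now accepts an optional per-layer list of `KVCache` objects and threads each cache through `Block` into `CausalSelfAttention`; passing `None` keeps the previous behaviour, with `_default_kv_cache` set to True and caches created internally. The following is a minimal sketch of the new call pattern, not part of the commit, and assumes only the signatures visible in this diff (how concrete `KVCache` instances are constructed lives in `litgpt/kvcache/base.py`, which is not shown in this excerpt).

# Sketch only: relies solely on the constructor signature shown in the diff above.
from typing import List, Optional

from litgpt.adapter import Config, GPT
from litgpt.kvcache.base import KVCache

def build_adapter_model(
    config: Config,
    kv_caches: Optional[List[KVCache]] = None,  # one cache per transformer layer
) -> GPT:
    # GPT.__init__ checks len(kv_caches) == config.n_layer and raises ValueError
    # otherwise. With kv_caches=None it keeps its default behaviour
    # (_default_kv_cache = True), and caches are created later, e.g. via
    # set_kv_cache() as used in litgpt/api.py below.
    return GPT(config, kv_cache=kv_caches)

Injecting the caches per layer keeps `Block` and `CausalSelfAttention` agnostic of how the caches are allocated; they only consume the `KVCache` interface.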
42 changes: 33 additions & 9 deletions litgpt/adapter_v2.py
@@ -9,7 +9,7 @@
"""

from dataclasses import dataclass
from typing import Any, Dict, Type, Optional
from typing import Any, Dict, Type, Optional, List

import torch
import torch.nn as nn
@@ -22,6 +22,7 @@
from litgpt.adapter import Config as BaseConfig
from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble
from litgpt.utils import map_old_state_dict_weights
from litgpt.kvcache.base import KVCache


@dataclass
@@ -64,20 +65,33 @@ def reset_parameters(self) -> None:

class GPT(BaseModel):
# Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here.
def __init__(self, config: Config) -> None:
def __init__(
self,
config: Config,
kv_cache: Optional[List[KVCache]] = None
) -> None:
nn.Module.__init__(self)
assert config.padded_vocab_size is not None
self.config = config

if kv_cache is not None:
if len(kv_cache) != config.n_layer:
raise ValueError(f"kv_cache length {len(kv_cache)} != {config.n_layer} = config.n_layer")
for kvc in kv_cache:
self._check_kv_cache(config, kvc)
self._default_kv_cache = False
else:
kv_cache = [None] * config.n_layer
self._default_kv_cache = True
self.lm_head = AdapterV2Linear(
config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
)
self.transformer = nn.ModuleDict(
dict(
wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
h=nn.ModuleList(
Block(config, block_idx)
for block_idx in range(config.n_layer)
Block(config, block_idx, kv_cache=kvc)
for block_idx, kvc in enumerate(kv_cache)
),
ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
)
@@ -103,18 +117,28 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa


class Block(BaseBlock):
def __init__(self, config: Config, block_idx: int) -> None:
super().__init__(config, block_idx)
self.attn = CausalSelfAttention(config, block_idx)
def __init__(
self,
config: Config,
block_idx: int,
kv_cache: Optional[KVCache] = None,
) -> None:
super().__init__(config, block_idx, kv_cache)
self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache)
self.mlp = config.mlp_class(config)


class CausalSelfAttention(BaseCausalSelfAttention):
"""A modification of `litgpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class"""

# Copy&paste from :class:`model.CausalSelfAttention`
def __init__(self, config: Config, block_idx: int) -> None:
super().__init__(config, block_idx)
def __init__(
self,
config: Config,
block_idx: int,
kv_cache: Optional[KVCache] = None,
) -> None:
super().__init__(config, block_idx, kv_cache)
# key, query, value projections for all heads, but in a batch
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
self.qkv = AdapterV2Linear(
13 changes: 9 additions & 4 deletions litgpt/api.py
@@ -448,6 +448,7 @@ def generate(
self,
prompt: str,
max_new_tokens: int = 50,
prompt_chunksize: int = 1,
temperature: float = 1.0,
top_k: Optional[int] = None,
top_p: float = 1.0,
@@ -461,6 +462,11 @@
model: The model to use.
prompt: The prompt string to use for generating the samples.
max_new_tokens: The maximum number of new tokens to return.
prompt_chunksize: If even the shortest prompt is longer than the KV
cache, prompts are processed in chunks of this size during the
prefill phase. Once the shortest prompt has been processed to its
end, generation proceeds with chunk size 1.
Defaults to 1; larger values are recommended for long prompts.
temperature: Scales the predicted logits by 1 / temperature.
top_k: If specified, only sample among the tokens with the k highest probabilities.
top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
@@ -500,15 +506,12 @@
self.kv_cache_initialized = True

# Dynamically grow the kv cache size if necessary
self.model.clear_kv_cache()
if not self.fixed_kv_cache_size and self.prev_generated_seq_length < max_returned_tokens:
tmp_device = self.model.mask_cache.device
self.model.clear_kv_cache()
self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=tmp_device)

else:
for block in self.model.transformer.h:
block.attn.kv_cache.reset_parameters()

self.prev_generated_seq_length = max_returned_tokens
self.model.eval()

@@ -517,6 +520,7 @@ def iterator():
model=self.model,
prompt=input_ids,
max_returned_tokens=max_returned_tokens,
prompt_chunksize=prompt_chunksize,
temperature=temperature,
top_k=top_k,
top_p=top_p,
@@ -536,6 +540,7 @@ def iterator():
model=self.model,
prompt=input_ids,
max_returned_tokens=max_returned_tokens,
prompt_chunksize=prompt_chunksize,
temperature=temperature,
top_k=top_k,
top_p=top_p,
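The `prompt_chunksize` argument added to `generate` above controls how prompts longer than the KV cache are prefilled. Below is a usage sketch, not part of the commit, assuming litgpt's high-level `LLM` API (`LLM.load` / `generate`); the checkpoint name is purely illustrative.

# Sketch only: the checkpoint name is illustrative, not taken from this commit.
from litgpt.api import LLM

llm = LLM.load("microsoft/phi-2")

# With the default prompt_chunksize=1, a prompt longer than the KV cache is
# prefilled one token at a time; per the docstring above, a larger chunk size
# is recommended for long prompts to amortise the prefill cost.
text = llm.generate(
    "Summarise the benefits of KV caching in transformer inference.",
    max_new_tokens=50,
    prompt_chunksize=16,
)
print(text)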
(Diffs for the remaining changed files are not shown in this excerpt.)
