Merge pull request #382 from ceerRep/server-prefix-cache
fix server and add prefix cache for server
Atream authored Feb 19, 2025
2 parents ea75849 + 584c7d5 commit cf4da5f
Showing 7 changed files with 192 additions and 89 deletions.
14 changes: 13 additions & 1 deletion ktransformers/models/custom_cache.py
@@ -172,7 +172,19 @@ def reset(self):
                self.key_cache[layer_idx].zero_()
            if self.value_cache[layer_idx] is not None:
                self.value_cache[layer_idx].zero_()
            self.past_tokens[layer_idx] = 0

    def remove_suffix(self, start_pos):
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            if self.is_MLA:
                k_cache = self.key_cache[layer_idx]
                k_cache.view(-1, k_cache.shape[-1])[start_pos:].zero_()
            else:
                self.key_cache[layer_idx][..., start_pos:, :].zero_()
                self.value_cache[layer_idx][..., start_pos:, :].zero_()
            self.past_tokens[layer_idx] = start_pos

    def get_max_cache_shape(self) -> Tuple[int, int, int, int]:
        """Returns the maximum shape of the cache."""
        return self.max_cache_len
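The new remove_suffix truncates the static cache back to a given position so a later request that shares a prefix with the previous one can keep the cached entries and only recompute the tail. A toy, self-contained sketch of the non-MLA path (ToyStaticCache below is an illustration, not the project's actual cache class):

import torch

class ToyStaticCache:
    """Toy stand-in for the static KV cache: one layer, shape [1, max_len, dim]."""
    def __init__(self, max_len=8, dim=4):
        self.key_cache = [torch.ones(1, max_len, dim)]
        self.value_cache = [torch.ones(1, max_len, dim)]
        self.past_tokens = [max_len]

    def remove_suffix(self, start_pos):
        for layer_idx in range(len(self.key_cache)):
            # In-place ops keep the preallocated buffers at the same address.
            self.key_cache[layer_idx][..., start_pos:, :].zero_()
            self.value_cache[layer_idx][..., start_pos:, :].zero_()
            self.past_tokens[layer_idx] = start_pos

cache = ToyStaticCache()
cache.remove_suffix(3)                 # keep the first 3 cached positions, drop the rest
print(cache.key_cache[0][0, :, 0])     # tensor([1., 1., 1., 0., 0., 0., 0., 0.])
print(cache.past_tokens)               # [3]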
68 changes: 46 additions & 22 deletions ktransformers/operators/attention.py
@@ -129,8 +129,8 @@ def forward_chunck(
        # compressed_kv [pages, page_size, 1, self.kv_lora_rank]

        q_absorb, out_absorb = self.get_absorbed()
        if hasattr(self.orig_module, 'kv_b_proj'):
            del self.orig_module.kv_b_proj
        # if hasattr(self.orig_module, 'kv_b_proj'):
        #     del self.orig_module.kv_b_proj

        # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
        # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
@@ -222,6 +222,16 @@ def forward_linux_triton(
        compressed_kv = self.kv_a_layernorm(compressed_kv)
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
        compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)

        kv_seq_len = q_len
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

cos, sin = self.rotary_emb(q_pe, position_ids)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
@@ -293,26 +303,28 @@ def forward_linux_triton(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
k_pe.squeeze(0)
compressed_kv.squeeze(0)
past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
k_pe.unsqueeze(0)
compressed_kv.unsqueeze(0)

k_pe = k_pe[:, :q_len]
compressed_kv = compressed_kv[:, :q_len]
compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
compressed_kv, k_pe = torch.split(
compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)
k_pe = k_pe[:, :kv_seq_len]
compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)
compressed_kv = compressed_kv[:, :kv_seq_len]
kv = (
self.kv_b_proj(compressed_kv)
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
)
k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)

value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

attn_output = flash_attn_func(
@@ -362,6 +374,16 @@ def forward_linux_flashinfer(
        compressed_kv = self.kv_a_layernorm(compressed_kv)
        k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim)
        compressed_kv = compressed_kv.view(bsz, q_len, 1, self.kv_lora_rank)

        kv_seq_len = q_len
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

cos, sin = self.rotary_emb(q_pe, position_ids)
q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, unsqueeze_dim=2)
@@ -441,26 +463,28 @@ def forward_linux_flashinfer(
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
k_pe.squeeze(0)
compressed_kv.squeeze(0)
past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
k_pe.unsqueeze(0)
compressed_kv.unsqueeze(0)

k_pe = k_pe[:, :q_len]
compressed_kv = compressed_kv[:, :q_len]
compressed_kv_with_k_pe, _ = past_key_value.update(compressed_kv, k_pe, self.layer_idx, cache_kwargs)
compressed_kv, k_pe = torch.split(
compressed_kv_with_k_pe, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
k_pe = k_pe.view(bsz, -1, self.qk_rope_head_dim)
k_pe = k_pe[:, :kv_seq_len]
compressed_kv = compressed_kv.view(bsz, -1, self.kv_lora_rank)
compressed_kv = compressed_kv[:, :kv_seq_len]
kv = (
self.kv_b_proj(compressed_kv)
.view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
.view(bsz, kv_seq_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
)
k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
query_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
query_states[:, :, :, : self.qk_nope_head_dim] = q_nope
query_states[:, :, :, self.qk_nope_head_dim :] = q_pe

key_states = k_pe.new_empty(bsz, q_len, self.num_heads, self.q_head_dim)
key_states = k_pe.new_empty(bsz, kv_seq_len, self.num_heads, self.q_head_dim)
key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
key_states[:, :, :, self.qk_nope_head_dim:] = k_pe.view(bsz, kv_seq_len, 1, -1)

value_states = value_states.view(bsz, q_len, self.num_heads, self.v_head_dim)
value_states = value_states.view(bsz, kv_seq_len, self.num_heads, self.v_head_dim)
value_states_padded = torch.nn.functional.pad(value_states, [0, query_states.shape[-1] - value_states.shape[-1]], value=0)

attn_output = flash_attn_func(
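The reshaping in both attention paths relies on past_key_value.update returning the compressed latent KV and the rotary key concatenated along the last dimension for the whole usable sequence (cached prefix plus new tokens), which is then split back apart and expanded with kv_b_proj. A standalone sketch of the split step with made-up sizes (the dimensions below are illustrative, not read from the model config):

import torch

bsz, kv_seq_len = 1, 10
kv_lora_rank, qk_rope_head_dim = 512, 64   # illustrative MLA sizes, assumed

# Stand-in for what the cache update returns: latent KV and rotary key, concatenated.
compressed_kv_with_k_pe = torch.randn(bsz, kv_seq_len, 1, kv_lora_rank + qk_rope_head_dim)

compressed_kv, k_pe = torch.split(
    compressed_kv_with_k_pe, [kv_lora_rank, qk_rope_head_dim], dim=-1
)
k_pe = k_pe.view(bsz, -1, qk_rope_head_dim)[:, :kv_seq_len]
compressed_kv = compressed_kv.view(bsz, -1, kv_lora_rank)[:, :kv_seq_len]

print(compressed_kv.shape)  # torch.Size([1, 10, 512])
print(k_pe.shape)           # torch.Size([1, 10, 64])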
12 changes: 5 additions & 7 deletions ktransformers/server/api/openai/endpoints/chat.py
@@ -5,18 +5,15 @@
from fastapi.requests import Request
from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import chat_stream_response
from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate,ChatCompletionChunk,ChatCompletionObject
from ktransformers.server.schemas.endpoints.chat import ChatCompletionCreate,ChatCompletionChunk,ChatCompletionObject, Usage
from ktransformers.server.backend.base import BackendInterfaceBase
from ktransformers.server.config.config import Config

router = APIRouter()

models = [
    {"id": "0", "name": "ktranformers-model"},
]

@router.get('/models', tags=['openai'])
async def list_models():
    return models
    return [{"id": Config().model_name, "name": Config().model_name}]


@router.post('/chat/completions', tags=['openai'])
@@ -36,7 +33,8 @@ async def inner():
yield chunk
return chat_stream_response(request,inner())
else:
comp = ChatCompletionObject(id=id,object='chat.completion.chunk',created=int(time()))
comp = ChatCompletionObject(id=id,object='chat.completion',created=int(time()))
comp.usage = Usage(completion_tokens=1, prompt_tokens=1, total_tokens=2)
async for token in interface.inference(input_message,id):
comp.append_token(token)
return comp
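With these changes, /models reports the configured model name and the non-streaming branch returns a 'chat.completion' object with a usage block attached. Something like the following exercises both endpoints against a running server (a sketch; the base URL, port, and /v1 prefix are assumptions about a typical deployment, not taken from this diff):

import requests

base = "http://localhost:10002/v1"  # assumed host/port/prefix

# List the configured model.
print(requests.get(f"{base}/models").json())

# Non-streaming chat completion.
resp = requests.post(
    f"{base}/chat/completions",
    json={
        "model": "ktransformers-model",
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": False,
    },
)
print(resp.json())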
3 changes: 2 additions & 1 deletion ktransformers/server/args.py
@@ -90,7 +90,8 @@ def parse_args(self):
        # user config
        parser.add_argument("--user_secret_key", type=str, default=self.cfg.user_secret_key)
        parser.add_argument("--user_algorithm", type=str, default=self.cfg.user_algorithm)
        parser.add_argument("--force_think", type=bool, default=self.cfg.user_force_think)
        parser.add_argument("--force_think", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.user_force_think)
        parser.add_argument("--use_cuda_graph", action=argparse.BooleanOptionalAction, type=bool, default=self.cfg.use_cuda_graph)

        # web config
        parser.add_argument("--web_cross_domain", type=bool, default=self.cfg.web_cross_domain)
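The flags above now use argparse.BooleanOptionalAction (Python 3.9+), which generates a paired --flag/--no-flag switch and sidesteps the common type=bool pitfall, where any non-empty command-line value, including "False", parses as True. A minimal standalone demonstration (not the server's actual parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--force_think", action=argparse.BooleanOptionalAction, default=False)

print(parser.parse_args([]).force_think)                    # False (default)
print(parser.parse_args(["--force_think"]).force_think)     # True
print(parser.parse_args(["--no-force_think"]).force_think)  # False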
71 changes: 47 additions & 24 deletions ktransformers/server/backend/interfaces/ktransformers.py
@@ -15,7 +15,9 @@
from ktransformers.local_chat import custom_models, default_optimize_rules
from ktransformers.util.utils import get_device


warm_uped = False

class KTransformersThreadContext(TransformersThreadContext):
pass

@@ -74,13 +76,13 @@ def __init__(self, args: ConfigArgs = default_args):
self._infer_lock = asyncio.Lock()

def decode_one_tokens(self):
global warm_uped

device_map = self.model.gguf_loader.tensor_device_map
torch_device = get_device("blk.0.self_attn", device_map)
torch_device = "cuda:0" if torch_device == "cuda" else torch_device
global warm_uped
torch.cuda.set_device(torch_device)
if self.args.use_cuda_graph and warm_uped == True:

if warm_uped and self.args.use_cuda_graph:
if not hasattr(self, "cuda_graph_runner"):
self.cuda_graph_runner = CUDAGraphRunner()
self.cuda_graph_runner.capture(
@@ -127,34 +129,54 @@ def decode_one_tokens(self):
@torch.no_grad
def prefill(self, input_ids: torch.Tensor, is_new: bool):
input_ids_length = input_ids.shape[-1]
self.profiler.set_counter("prefill", input_ids_length)
logger.debug(f"input_ids: {input_ids.shape}")

device = self.device_map.get("blk.0.self_attn", {}).get("generate_device", "cuda:0")
device = "cuda:0" if device == "cuda" else device

if is_new:
self.cache.reset()
self.ever_generated_ids.clear()
former_seq_length = 0
self.seq_length = input_ids_length
self.generated_ids = torch.zeros(
self.args.batch_size,
self.seq_length + self.args.max_new_tokens + 1,
dtype=torch.int,
device=self.args.device,
)
else:
logger.debug(f"generate_ids: {self.generated_ids.shape}")
former_seq_length = self.seq_length
self.seq_length += input_ids_length
expected_length = self.seq_length + self.args.max_new_tokens + 1
delta_length = expected_length - self.generated_ids.shape[-1]
if delta_length > 0:
new_generate_ids = torch.zeros(
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
same_prefix = 0
flat_input_ids = input_ids.flatten()

if getattr(self, 'generated_ids', None) is None:
self.generated_ids = torch.zeros(
self.args.batch_size,
input_ids.shape[-1] + self.args.max_new_tokens + 1,
dtype=torch.int,
device=self.args.device,
)
self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)
self.seq_length = 1

flat_prev_ids = self.generated_ids.flatten()
for i in range(min(self.seq_length, flat_input_ids.shape[0]) - 1):
if flat_input_ids[i] == flat_prev_ids[i]:
same_prefix += 1
else:
break

logger.debug(f"same prefix len: {same_prefix}")
self.cache.remove_suffix(same_prefix)
self.seq_length = same_prefix
self.generated_ids = self.generated_ids[..., :same_prefix]
input_ids = input_ids[..., same_prefix:]
input_ids_length = input_ids.shape[-1]

self.ever_generated_ids.clear()
self.profiler.set_counter("prefill", input_ids_length)
logger.debug(f"input_ids: {input_ids.shape}")

logger.debug(f"generate_ids: {self.generated_ids.shape}")
former_seq_length = self.seq_length
self.seq_length += input_ids_length
expected_length = self.seq_length + self.args.max_new_tokens + 1
delta_length = expected_length - self.generated_ids.shape[-1]
if delta_length > 0:
new_generate_ids = torch.zeros(
self.args.batch_size, delta_length, dtype=torch.int, device=self.args.device
)
self.generated_ids = torch.cat([self.generated_ids, new_generate_ids], dim=-1)

logger.debug(f"cache position: {former_seq_length} to {self.seq_length}")
cache_position = torch.arange(former_seq_length, self.seq_length, device=device)
self.generated_ids[:, cache_position] = input_ids.to(self.args.device).to(torch.int)
@@ -176,6 +198,7 @@ def prefill(self, input_ids: torch.Tensor, is_new: bool):
else:
logits = self.model(inputs_embeds=inputs_embeds, return_dict=False)[0]

self.prepare_logits_wrapper(input_ids, device)
next_token = self.logits_to_token(logits[0, -1, :])
yield self.append_new_tokens(next_token)

Expand All @@ -187,4 +210,4 @@ def active_cache_position(self):
    async def inference(self, local_messages, thread_id: str):
        async with self._infer_lock:
            async for v in super().inference(local_messages, thread_id):
                yield v
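For context on the use_cuda_graph / warm_uped gating in decode_one_tokens, the underlying capture-then-replay pattern in PyTorch looks roughly like this (a minimal sketch of torch.cuda.CUDAGraph; ktransformers' own CUDAGraphRunner wraps this and is not shown in the diff):

import torch

if torch.cuda.is_available():
    static_x = torch.randn(8, 8, device="cuda")
    weight = torch.randn(8, 8, device="cuda")

    # Warm up on a side stream so one-time initialisation is not captured.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            static_y = static_x @ weight
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_y = static_x @ weight  # captured work; buffers must keep fixed addresses

    # Each decode step: copy new data into the static input, then replay the graph.
    static_x.copy_(torch.randn(8, 8, device="cuda"))
    graph.replay()
    torch.cuda.synchronize()
    print(static_y.sum().item())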