Merge pull request #487 from kvcache-ai/clean_pr
clean PR code and disable flashinfer
Atream authored Feb 19, 2025
2 parents cf4da5f + a529518 commit 89f8218
Showing 3 changed files with 13 additions and 23 deletions.
24 changes: 7 additions & 17 deletions ktransformers/operators/attention.py
@@ -58,18 +58,10 @@ def __init__(self,
     def get_absorbed(self) -> Tuple[torch.Tensor, torch.Tensor]:
         if not (hasattr(self, 'q_absorb') and hasattr(self, 'out_absorb')):
             kv_b_proj = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank)
-            q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].reshape(-1, self.kv_lora_rank)
-            out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].reshape(-1, self.kv_lora_rank)
-            self.q_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.qk_nope_head_dim,
-                                      bias=False, dtype=q_absorb.dtype, device=q_absorb.device)
-            self.q_absorb.weight.data = q_absorb
-            self.out_absorb = nn.Linear(self.kv_lora_rank, self.num_heads * self.v_head_dim,
-                                        bias=False, dtype=out_absorb.dtype, device=out_absorb.device)
-            self.out_absorb.weight.data = out_absorb
-            #del self.orig_module.kv_b_proj
-        q_absorb = self.q_absorb.weight.view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
-        out_absorb = self.out_absorb.weight.view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
-        return q_absorb, out_absorb
+            self.q_absorb = kv_b_proj[:, :self.qk_nope_head_dim, :].view(self.num_heads, self.qk_nope_head_dim, self.kv_lora_rank)
+            self.out_absorb = kv_b_proj[:, self.qk_nope_head_dim:, :].view(self.num_heads, self.v_head_dim, self.kv_lora_rank)
+
+        return self.q_absorb, self.out_absorb

     def forward_chunck(
         self,
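The new get_absorbed simply returns per-head views sliced out of the fused kv_b_proj weight, instead of materializing separate nn.Linear modules and re-viewing their weights on every call as the removed code did. A minimal standalone sketch of that path, with hypothetical sizes chosen only for illustration:

import torch

# Hypothetical sizes, chosen only to mirror the names in the diff above.
num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 128, 128, 512

# Fused kv_b_proj weight: (num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
kv_b_proj_weight = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

# New path: reshape once, slice out the W^UK and W^UV blocks, keep per-head views.
kv_b_proj = kv_b_proj_weight.view(num_heads, -1, kv_lora_rank)
q_absorb = kv_b_proj[:, :qk_nope_head_dim, :].view(num_heads, qk_nope_head_dim, kv_lora_rank)
out_absorb = kv_b_proj[:, qk_nope_head_dim:, :].view(num_heads, v_head_dim, kv_lora_rank)

print(q_absorb.shape)    # torch.Size([4, 128, 512])
print(out_absorb.shape)  # torch.Size([4, 128, 512])

Because slicing and same-shape .view never copy, the absorbed tensors share storage with the original kv_b_proj weight, so no extra parameter memory is allocated.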
@@ -105,7 +97,7 @@ def forward_chunck(
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} "
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
@@ -129,8 +121,6 @@ def forward_chunck(
         # compressed_kv [pages, page_size, 1, self.kv_lora_rank]

         q_absorb, out_absorb = self.get_absorbed()
-        # if hasattr(self.orig_module, 'kv_b_proj'):
-        #     del self.orig_module.kv_b_proj

         # q_nope [bsz, self.num_heads, q_len, self.qk_nope_head_dim]
         # q_pe [bsz, self.num_heads, q_len, self.qk_rope_head_dim]
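For context, the shape comments above match the standard "matrix absorption" trick in MLA-style attention: the query is projected into the latent space with q_absorb, scores are computed directly against the compressed KV cache, and out_absorb maps the latent context back to per-head value dimensions. The forward code that consumes these tensors is not part of this hunk, so the following is only a rough sketch under assumed shapes; it omits the RoPE (q_pe) term and the exact softmax scaling:

import torch

# Assumed shapes, mirroring the comments above.
bsz, num_heads, q_len, kv_len = 1, 4, 8, 32
qk_nope_head_dim, v_head_dim, kv_lora_rank = 128, 128, 512

q_nope        = torch.randn(bsz, num_heads, q_len, qk_nope_head_dim)
compressed_kv = torch.randn(bsz, kv_len, kv_lora_rank)   # latent KV cache
q_absorb      = torch.randn(num_heads, qk_nope_head_dim, kv_lora_rank)
out_absorb    = torch.randn(num_heads, v_head_dim, kv_lora_rank)

# Absorb W^UK into the query so scores are computed in the latent space;
# the full per-head keys never have to be materialized.
q_latent = torch.matmul(q_nope, q_absorb)                 # [bsz, heads, q_len, kv_lora_rank]
scores   = torch.einsum("bhqc,bkc->bhqk", q_latent, compressed_kv)
probs    = torch.softmax(scores / qk_nope_head_dim**0.5, dim=-1)

# Attend in the latent space, then map back out with W^UV (out_absorb).
ctx_latent = torch.einsum("bhqk,bkc->bhqc", probs, compressed_kv)
attn_out   = torch.einsum("bhqc,hdc->bhqd", ctx_latent, out_absorb)
print(attn_out.shape)                                     # torch.Size([1, 4, 8, 128])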
@@ -227,7 +217,7 @@ def forward_linux_triton(
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    f"The cache structure has changed since transformer version v4.36. If you are using {self.__class__.__name__} "
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
@@ -379,7 +369,7 @@ def forward_linux_flashinfer(
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    f"The cache structure has changed since version transformer verision v4.36. If you are using {self.__class__.__name__} "
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
2 changes: 1 addition & 1 deletion ktransformers/operators/flashinfer_wrapper.py
@@ -9,7 +9,7 @@

try:
    import flashinfer
-    flashinfer_enabled = True
+    flashinfer_enabled = False # disabled now, TODO:use new version of flashinfer and enable
    print("found flashinfer")

except ImportError:
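The flag above acts as a kill switch layered on the usual optional-import guard: even when flashinfer imports successfully, flashinfer_enabled stays False until a newer flashinfer version is supported, which presumably routes attention through the Triton path (forward_linux_triton) instead of forward_linux_flashinfer. A condensed sketch of the pattern; the dispatch helper below is hypothetical and not part of the repository:

# Minimal sketch of the optional-import guard plus hard kill switch used above.
try:
    import flashinfer  # optional dependency
    # Hard-disabled for now; to be re-enabled once a newer flashinfer is supported.
    flashinfer_enabled = False
except ImportError:
    flashinfer_enabled = False

def pick_linux_attention_path() -> str:
    """Return which attention implementation to dispatch to (illustrative only)."""
    return "forward_linux_flashinfer" if flashinfer_enabled else "forward_linux_triton"

print(pick_linux_attention_path())  # always "forward_linux_triton" while the flag is hard-disabled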
10 changes: 5 additions & 5 deletions ktransformers/server/backend/interfaces/transformers.py
@@ -381,13 +381,13 @@ async def inference(self, local_messages, thread_id: str):

         self.profiler.create_and_start_timer("prefill")

+
+        if Config().user_force_think:
+            think = '<think>\n'
+            print(think, end="",flush=True)
+            yield think

         for t in self.prefill(input_ids, self.check_is_new(thread_id)):
-            # output think token after prefill done
-            if Config().user_force_think:
-                think = '<think>\n'
-                print(think, end="",flush=True)
-                yield think
             if t is not None:
                 print(t, end="",flush=True)
                 yield t
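Functionally, this moves the forced '<think>\n' prefix out of the prefill loop: it is now emitted exactly once before prefill starts, rather than being re-yielded on every value produced by self.prefill. A condensed, self-contained sketch of the new ordering; the generator below uses stand-in arguments instead of the real Config and prefill objects:

from typing import AsyncIterator, Iterable, Optional

async def inference_sketch(prefill_tokens: Iterable[Optional[str]],
                           force_think: bool) -> AsyncIterator[str]:
    # Emit the forced '<think>\n' prefix exactly once, before any prefill output,
    # instead of re-yielding it on every iteration of the prefill loop as before.
    if force_think:
        print('<think>\n', end="", flush=True)
        yield '<think>\n'
    for t in prefill_tokens:
        if t is not None:
            print(t, end="", flush=True)
            yield t

# Usage: async for chunk in inference_sketch(["Hello", None, " world"], force_think=True): ...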
