⚡ ready to publish
KMSorSMS committed Feb 10, 2025
1 parent f892d22 commit 83401db
Showing 6 changed files with 157 additions and 19 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -31,7 +31,7 @@ Our vision for KTransformers is to serve as a flexible platform for experimentin
* **Aug 12, 2024**: Support multiple GPU; Support new model: mixtral 8\*7B and 8\*22B; Support q2k, q3k, q5k dequant on gpu.
* **Aug 9, 2024**: Support windows native.

<h2 id="show-cases">🔥 Show Cases</h2>
<h2 id="show-cases">🌟 Show Cases</h2>

<div>
<h3>GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM</h3>
41 changes: 35 additions & 6 deletions doc/en/DeepseekR1_V3_tutorial.md
@@ -1,7 +1,14 @@
# GPT-4/o1-level Local VSCode Copilot on a Desktop with only 24GB VRAM
# SUMMARY

- **Feb 10, 2025**: Support DeepSeek-R1 and V3 on single GPU (24GB VRAM)/multi-GPU and 382GB DRAM, up to 3~64x speedup.<br>
> **Feb 10, 2025**: Support DeepSeek-R1 and V3 on single GPU (24GB VRAM)/multi-GPU and 382GB DRAM, up to 3~64x speedup.<br>
Hi, we're the KTransformers team (formerly known for our local CPU/GPU hybrid inference open source project with DeepSeek-V2).

We've heard your requests for DeepSeek-R1/V3 support—and we're excited to finally deliver!
Apologies for the wait, but we've been cooking up something truly amazing!

Today, we're proud to announce that we not only support DeepSeek-R1/V3, as showcased in the video below:

https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285

@@ -14,9 +21,10 @@ https://github.com/user-attachments/assets/ebd70bfa-b2c1-4abb-ae3b-296ed38aa285
- Decode Speed (tokens/s):
- KTransformers: 8.73 (32 cores) → 11.26 (dual-socket, 2×32 cores) → 13.69 (selectively using 6 experts, V0.3 only)
- Compared to 4.51 tokens/s in llama.cpp with 2×32 cores, achieving up to **3.03× speedup**.
- Upcoming Open Source Release:
- AMX optimizations and selective expert activation will be open-sourced in V0.3.
- Currently available only in preview binary distribution, which can be found [here](xxx).


But we're also previewing our upcoming optimizations, including an Intel AMX-accelerated kernel and a selective expert activation method, which will significantly enhance performance. With V0.3-preview, we achieve up to 286 tokens/s for prefill, making it up to **64× faster than llama.cpp** for local inference.
The binary distribution is available now, and the source code will follow as soon as possible! Check out the details [here](xxx).
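
As a quick sanity check, the headline figures above can be reproduced with a bit of arithmetic. The short sketch below uses only the numbers quoted in this section; the implied llama.cpp prefill baseline is derived, not measured here:

```python
# Back-of-the-envelope check of the quoted speedups; all figures come from the text above.
llama_cpp_decode = 4.51   # tokens/s, llama.cpp decode on 2 x 32 cores
kt_decode_v03 = 13.69     # tokens/s, KTransformers V0.3 decode with 6 selected experts
print(kt_decode_v03 / llama_cpp_decode)   # ~3.0x, the "up to 3.03x" decode speedup (up to rounding)

kt_prefill_v03 = 286.0    # tokens/s, V0.3-preview prefill
claimed_speedup = 64.0    # "up to 64x faster than llama.cpp" for prefill
print(kt_prefill_v03 / claimed_speedup)   # ~4.5 tokens/s, the implied llama.cpp prefill baseline
```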


## Prerequisites
@@ -98,11 +106,32 @@ python ./ktransformers/local_chat.py --model_path <your model path> --gguf_path
<when you see chat, then press enter to load the text prompt_file>
```
The parameters have the same meaning. Since we use a dual-socket machine, we set cpu_infer to 65.

### V0.3 Showcase
#### Dual socket version (64 cores)
Our local_chat test command is:
``` shell
python -m ktransformers.local_chat --model_path <your model path> --gguf_path <your gguf path> --prompt_file <your prompt txt file> --cpu_infer 65 --cache_lens 1536
<when you see chat, then press enter to load the text prompt_file>
```
The parameters have the same meaning as in V0.2. Since we use a dual-socket machine, we set cpu_infer to 65.
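
If your hardware differs, a reasonable starting point is to derive `--cpu_infer` from the core count reported by the OS and then tune from there. The sketch below is illustrative only and not part of the KTransformers tooling:

```python
import os

# Illustrative heuristic: start --cpu_infer near the logical core count reported by the OS,
# then adjust empirically; a larger value is not automatically better.
cores = os.cpu_count() or 1
print(f"suggested starting point: --cpu_infer {cores}")
```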

## Some Explanations
1. We also want to make further use of the two NUMA nodes on the Xeon Gold CPU.
To avoid the cost of data transfer between nodes, we "copy" the critical matrices onto
both nodes, which consumes more memory but accelerates the prefill and decoding process.
However, this method takes a lot of memory and is slow when loading weights, so be patient when loading
and monitor the memory usage. (we are considering making this method optional). We are going to optimize this huge memory overhead. Stay tuned~ <br>
and monitor the memory usage. We are going to optimize this huge memory overhead. Stay tuned~ <br>
2. The command arg `--cpu_infer 65` specifies how many cores to use (it is fine if it exceeds the physical core count,
but more is not always better; adjust it according to your actual number of cores).<br>

3. Why CPU/GPU Hybrid Inference?
DeepSeek's MLA operators are highly computationally intensive. While running everything on CPU is possible, offloading the heavy computations to the GPU results in a massive performance boost.

4. Where Does the Speedup Come From?

- Expert Offload: Unlike traditional layer-based or KVCache offloading (as seen in llama.cpp), we offload the expert computation to the CPU and MLA/KVCache to the GPU, aligning perfectly with DeepSeek’s architecture for optimal efficiency (a toy sketch of this split appears after this list).
- Intel AMX Optimization – Our AMX-accelerated kernel is meticulously tuned, running several times faster than existing llama.cpp implementations. We plan to open-source this kernel after cleaning it up and are considering upstream contributions to llama.cpp.

5. Why Intel CPUs?
Intel is currently the only CPU vendor that supports AMX-like instructions, which deliver significantly better performance compared to AVX-only alternatives.
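
To make the expert-offload idea in point 4 concrete, here is a minimal toy sketch (illustrative only, with made-up module names; it is not KTransformers' actual implementation). Expert MLP weights stay in CPU memory while routing, MLA/attention, and the KVCache stay on the GPU; only activations cross the boundary:

```python
import torch
import torch.nn as nn

class ToyExpertOffloadMoE(nn.Module):
    """Toy CPU-expert / GPU-router placement; not the real KTransformers kernel."""

    def __init__(self, hidden: int = 128, n_experts: int = 4, top_k: int = 2):
        super().__init__()
        self.top_k = top_k
        self.router = nn.Linear(hidden, n_experts)  # lives on the GPU with attention/KVCache
        # Expert weights are deliberately kept on the CPU, mimicking expert offload.
        self.experts = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(n_experts))

    def forward(self, x_gpu: torch.Tensor) -> torch.Tensor:
        topk = self.router(x_gpu).topk(self.top_k, dim=-1).indices  # route on the GPU
        x_cpu = x_gpu.to("cpu")              # move activations, not weights
        topk_cpu = topk.to("cpu")
        out_cpu = torch.zeros_like(x_cpu)
        for e_id, expert in enumerate(self.experts):
            hit = (topk_cpu == e_id).any(dim=-1)
            if hit.any():
                out_cpu[hit] += expert(x_cpu[hit])   # expert GEMMs run on the CPU
        return out_cpu.to(x_gpu.device)              # back to the GPU for MLA/attention

device = "cuda" if torch.cuda.is_available() else "cpu"
layer = ToyExpertOffloadMoE()
layer.router.to(device)                  # only the GPU-side piece is moved
tokens = torch.randn(8, 128, device=device)
print(layer(tokens).shape)               # torch.Size([8, 128])
```

The real implementation additionally relies on AMX-optimized CPU kernels and keeps MLA/KVCache on the GPU; the sketch only shows the placement split and the routing step.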
123 changes: 116 additions & 7 deletions ktransformers/operators/RoPE.py
@@ -18,6 +18,9 @@
from ktransformers.models.modeling_deepseek import (
    DeepseekV2YarnRotaryEmbedding,
    DeepseekV2RotaryEmbedding,
    yarn_get_mscale,
    yarn_linear_ramp_mask,
    yarn_find_correction_range
)
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.custom_gguf import GGUFLoader
@@ -188,7 +191,33 @@ def load(self):
            self.orig_module.mscale_all_dim,
        )

class DeepSeekV3YarnRotaryEmbedding(BaseInjectedModule, DeepseekV3RotaryEmbedding):
# class DeepSeekV3YarnRotaryEmbedding(BaseInjectedModule, DeepseekV3RotaryEmbedding):
# def __init__(
# self,
# key: str,
# gguf_loader: GGUFLoader,
# config: PretrainedConfig,
# orig_module: nn.Module,
# # device: str = "cuda",
# generate_device: str = "cuda",
# prefill_device: str = "cuda",
# **kwargs,
# ):
# BaseInjectedModule.__init__(
# self, key, gguf_loader, config, orig_module, generate_device, **kwargs
# )
# self.generate_device = generate_device
# self.prefill_device = prefill_device

# def load(self):
# # TODO support perlayer prefill
# self.orig_module.__init__(
# self.config,
# device=self.generate_device
# )
# return

class YarnRotaryEmbeddingV3(BaseInjectedModule):
    def __init__(
        self,
        key: str,
@@ -205,14 +234,94 @@ def __init__(
        )
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    def load(self):
        # TODO support perlayer prefill
        self.orig_module.__init__(
            self.config,
            device=self.generate_device
        kwargs = {
            key: self.config.rope_scaling[key]
            for key in [
                "original_max_position_embeddings",
                "beta_fast",
                "beta_slow",
                "mscale",
                "mscale_all_dim",
            ]
            if key in self.config.rope_scaling
        }
        self._init(
            dim=self.config.qk_rope_head_dim,
            max_position_embeddings=self.config.max_position_embeddings,
            base=self.config.rope_theta,
            device=self.device,
            scaling_factor=self.config.rope_scaling["factor"],
            **kwargs,
        )
        return

    @torch.no_grad()
    def forward(self, x, position_ids):
        # x: [bs, num_attention_heads, seq_len, head_size]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self._mscale
            sin = emb.sin() * self._mscale
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

    def _init(
        self,
        dim,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        original_max_position_embeddings=4096,
        beta_fast=32,
        beta_slow=1,
        mscale=1,
        mscale_all_dim=0,
    ):
        self.original_max_position_embeddings = original_max_position_embeddings
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.mscale = mscale
        self.mscale_all_dim = mscale_all_dim
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
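
        # Descriptive note: freq_extra below holds the original (unscaled) inverse frequencies
        # and freq_inter the same frequencies divided by scaling_factor; the ramp mask keeps
        # high-frequency dims at their original rate and interpolates the low-frequency dims
        # (YaRN), while _mscale rescales cos/sin following DeepSeek's YaRN convention.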

        freq_extra = 1.0 / (
            self.base
            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )
        freq_inter = 1.0 / (
            self.scaling_factor
            * self.base
            ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            dim,
            self.base,
            self.original_max_position_embeddings,
        )
        inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to(
            device=device, dtype=torch.float32
        )
        self.inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask
        self._mscale = float(
            yarn_get_mscale(self.scaling_factor, self.mscale)
            / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim)
        )
        # For BC we register cos and sin cached
        self.max_seq_len_cached = max_position_embeddings

class DynamicNTKScalingRotaryEmbedding(
BaseInjectedModule, LlamaDynamicNTKScalingRotaryEmbedding
@@ -10,15 +10,15 @@
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.RotaryEmbeddingV3
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.RotaryEmbeddingV3
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
@@ -10,15 +10,15 @@
name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.RotaryEmbeddingV3
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:0"
prefill_device: "cuda:0"
- match:
name: "^model\\.layers\\.([3456][0-9])\\."
class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
replace:
class: ktransformers.operators.RoPE.RotaryEmbeddingV3
class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
kwargs:
generate_device: "cuda:1"
prefill_device: "cuda:1"
@@ -1,7 +1,7 @@
- match:
    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.RotaryEmbeddingV3
    class: ktransformers.operators.RoPE.YarnRotaryEmbeddingV3
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
