Merge pull request #333 from kvcache-ai/feat_experts_gpu
toy support for experts on GPU, no CUDA Graph
Atream authored Feb 15, 2025
2 parents ae8da01 + 8ed8eb2 commit c5f036e
Showing 7 changed files with 204 additions and 68 deletions.
4 changes: 3 additions & 1 deletion doc/en/FAQ.md
@@ -25,7 +25,7 @@ from-https://github.com/kvcache-ai/ktransformers/issues/129#issue-2842799552
1. local_chat.py: You can increase the context window size by setting `--max_new_tokens` to a larger value.
2. server: Increase `--cache_lens` to a larger value.
2. Move more weights to the GPU.
Refer to the ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-marlin.yaml
Refer to the ktransformers/optimize/optimize_rules/DeepSeek-V3-Chat-multi-gpu-4.yaml
```yaml
- match:
name: "^model\\.layers\\.([4-10])\\.mlp\\.experts$" # inject experts in layer 4~10 as marlin expert
@@ -39,6 +39,8 @@ from-https://github.com/kvcache-ai/ktransformers/issues/129#issue-2842799552
You can modify the layers as you want, e.g. change `name: "^model\\.layers\\.([4-10])\\.mlp\\.experts$"` to `name: "^model\\.layers\\.([4-12])\\.mlp\\.experts$"` to move more weights to the GPU.

> Note: The first matched rule in yaml will be applied. For example, if you have two rules that match the same layer, only the first rule's replacement will be valid.
> Note: Currently, executing experts on the GPU conflicts with CUDA Graph, and without CUDA Graph there is a significant slowdown. Therefore, unless you have a substantial amount of VRAM (placing a single layer of experts for DeepSeek-V3/R1 on the GPU requires at least 5.6GB of VRAM), we do not recommend enabling this feature. We are actively working on optimization.
> Note: KExpertsTorch is untested.
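
A quick back-of-the-envelope check of the 5.6GB figure above (a sketch, assuming DeepSeek-V3's dimensions of hidden_size 7168, moe_intermediate_size 2048, and 256 routed experts, plus roughly 4-bit quantized expert weights):

```python
# Rough VRAM needed to hold one MoE layer's routed experts on the GPU (assumed dims).
hidden_size = 7168            # DeepSeek-V3 hidden dimension (assumed)
moe_intermediate_size = 2048  # per-expert FFN intermediate size (assumed)
n_routed_experts = 256        # routed experts per MoE layer (assumed)
bytes_per_weight = 0.5        # ~4-bit quantization

params_per_expert = 3 * hidden_size * moe_intermediate_size   # gate, up and down projections
layer_bytes = params_per_expert * n_routed_experts * bytes_per_weight
print(f"{layer_bytes / 1e9:.2f} GB")  # ~5.64 GB, consistent with the note above
```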


### Q: If I don't have enough VRAM, but I have multiple GPUs, how can I utilize them?
28 changes: 14 additions & 14 deletions ktransformers/ktransformers_ext/cuda/custom_gguf/dequant.cu
@@ -17,8 +17,8 @@
#include <c10/cuda/CUDAGuard.h>

__global__ void dequantize_q8_0_kernel(float* output, const float* scales, const int8_t* qs, int num_blocks, int blk_size) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
for(int i=0;i<blk_size;i++){
float scale = scales[block_id];
output[block_id * blk_size + i] = scale * qs[block_id * blk_size + i];
@@ -37,8 +37,8 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
}

__global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256);

const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 80)));
@@ -72,10 +72,10 @@ __global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size

__global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {

int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f;
for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
for (long long block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256);

uint32_t aux[4];
@@ -128,8 +128,8 @@ __global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size


__global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256);
// const uint8_t * q = data[i].qs;
const uint8_t * q = (uint8_t*)(data + block_id * 144 + 16);
@@ -152,8 +152,8 @@ __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size
}

__global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id = global_idx; block_id < num_blocks; block_id += blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256);

const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 0)));
@@ -181,8 +181,8 @@ __global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size
}

__global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256);
const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 208)));

@@ -215,8 +215,8 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
static constexpr __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

__global__ void dequantize_iq4_xs_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {
long long global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (long long block_id=global_idx; block_id<num_blocks; block_id+=blockDim.x * gridDim.x) {
float* __restrict__ output_blk = (float*)(output + block_id * 256);
const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size)));
const uint16_t scales_h = *(reinterpret_cast<uint16_t*>(data + block_id * blk_size + 2));
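
The int to long long change above widens the block index so that flat offsets such as `block_id * blk_size + i` and `block_id * 256` are computed in 64-bit arithmetic. A minimal sketch of the tensor sizes where 32-bit indexing breaks down, for example when a whole stacked expert tensor is dequantized in one call (dimensions assumed from the DeepSeek-V3 config):

```python
# Why a 32-bit element index overflows for a stacked expert tensor (assumed dims).
hidden_size = 7168
moe_intermediate_size = 2048
n_routed_experts = 256

total_elements = n_routed_experts * moe_intermediate_size * hidden_size
print(total_elements)        # 3,758,096,384 output floats for e.g. ffn_up_exps
print(2**31 - 1)             # 2,147,483,647 == INT32_MAX
assert total_elements > 2**31 - 1   # so block_id * blk_size + i needs 64-bit math
```
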
136 changes: 93 additions & 43 deletions ktransformers/operators/experts.py
@@ -18,6 +18,7 @@
import torch
import sys, os
from ktransformers.operators.base_operator import BaseInjectedModule
from tqdm import tqdm

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
@@ -225,6 +226,7 @@ def unload(self):
return

def load_weights(self, override_key: str | None = None, device: str = "cpu"):
# TODO: support Bias
res = {}
if override_key is not None:
keys = override_key
@@ -288,6 +290,8 @@ def __init__(
self.act_fn = ACT2FN[config.hidden_act]
assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
self.device = device
self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size

# create empty marlin experts according to the number of experts per token
# up
self.up_projs = [KLinearMarlin(key+ "." + "ffn_up_exps", gguf_loader, config, device=device) for i in range(self.expert_num)]
@@ -299,17 +303,34 @@
def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
if device is None: device = self.device
assert device.lower() != "cpu", "Marlin experts can only be loaded on GPU"
if w is None: w = self.load_weights()[self.key]

if isinstance(w, dict):
self.gate = w["gate"]
self.up = (w["up"])
self.down = (w["down"])
for i in range(self.expert_num):
self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
self.loaded_experts_idx.append(i)
if w is None:
w = self.load_weights()
load_by_experts = True

if load_by_experts:
if isinstance(w, dict):
self.gate = w["gate"]
self.up = (w["up"])
self.down = (w["down"])
for i in tqdm(range(self.expert_num), desc=f"Dequanting and quanting for KExpertsMarlin {self.key}"):
up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", self.up, i, self.elements_per_tensor, device=self.device)
gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", self.gate, i, self.elements_per_tensor, device=self.device)
down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", self.down, i, self.elements_per_tensor, device=self.device)

self.up_projs[i].load(nn.Parameter(up_weights), device=device)
self.gate_projs[i].load(nn.Parameter(gate_weights), device=device)
self.down_projs[i].load(nn.Parameter(down_weights), device=device)
self.loaded_experts_idx.append(i)
else:
if isinstance(w, dict):
self.gate = w["gate"]
self.up = (w["up"])
self.down = (w["down"])
for i in range(self.expert_num):
self.up_projs[i].load(nn.Parameter(self.up[i,...]), device=device)
self.gate_projs[i].load(nn.Parameter(self.gate[i,...]), device=device)
self.down_projs[i].load(nn.Parameter(self.down[i,...]), device=device)
self.loaded_experts_idx.append(i)
return

def unload(self):
@@ -329,20 +350,13 @@ def load_weights(self, override_key: str | None = None):
gate = None
up = None
down = None
gate_type = None
up_type = None
down_type = None

for key in keys:
if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
gate = self.gguf_loader.load_gguf_tensor(key + ".ffn_gate_exps.weight")
up = self.gguf_loader.load_gguf_tensor(key + ".ffn_up_exps.weight")
down = self.gguf_loader.load_gguf_tensor(key + ".ffn_down_exps.weight")
gate_type = self.gguf_loader.tensor_info[key + ".ffn_gate_exps.weight"]["ggml_type"]
up_type = self.gguf_loader.tensor_info[key + ".ffn_up_exps.weight"]["ggml_type"]
down_type = self.gguf_loader.tensor_info[key + ".ffn_down_exps.weight"]["ggml_type"]
# tensors = self.load_multi(key, [".ffn_gate_exps.weight", ".ffn_up_exps.weight", ".ffn_down_exps.weight"])
res = {key:{"gate": nn.Parameter(gate), "up": nn.Parameter(up), "down": nn.Parameter(down), "gate_type": gate_type, "up_type": up_type, "down_type": down_type}}
gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
res = {"gate": gate, "up": up, "down": down}
return res

def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
@@ -381,6 +395,7 @@ def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.T

return final_hidden_states.to(dtype=org_dtype, device=org_device)

# untested, CUDA OOM
class KExpertsTorch(KExpertsBase):
expert_num: int
loaded_experts_idx: list[int]
@@ -402,26 +417,65 @@ def __init__(
# self.loaded_experts_idx = []
self.act_fn = ACT2FN[config.hidden_act]
self.device = device
self.gate = None
self.up = None
self.donw = None
self.elements_per_tensor = config.moe_intermediate_size * config.hidden_size
self.gate = [None for _ in range(self.expert_num)]
self.up = [None for _ in range(self.expert_num)]
self.down = [None for _ in range(self.expert_num)]
self.dtype = torch.get_default_dtype()

def load(self, w: dict | nn.Parameter | tuple | None = None, device: str | None = None, warmup: bool = False):
if device is None: device = self.device
if w is None: w = self.load_weights(device=device)[self.key]

if isinstance(w, dict):
self.gate = w["gate"].to(device=device, dtype=self.dtype)
self.up = w["up"].to(device=device, dtype=self.dtype)
self.down = w["down"].to(device=device, dtype=self.dtype)
if w is None:
w = self.load_weights()
load_by_experts = True

if load_by_experts:
if isinstance(w, dict):
for i in tqdm(range(self.expert_num), desc=f"Dequanting for KExpertsTorch {self.key}"):
up_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_up_exps.weight", w["up"], i, self.elements_per_tensor, device=self.device)
gate_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_gate_exps.weight", w["gate"], i, self.elements_per_tensor, device=self.device)
down_weights = self.gguf_loader.load_expert_tensor(self.key + ".ffn_down_exps.weight", w["down"], i, self.elements_per_tensor, device=self.device)

self.up[i] = up_weights
self.gate[i] = gate_weights
self.down[i] = down_weights
else:
if isinstance(w, dict):
for i in range(self.expert_num):
self.gate[i] = w["gate"][i, ...].to(device=device, dtype=self.dtype)
self.up[i] = w["up"][i, ...].to(device=device, dtype=self.dtype)
self.down[i] = w["down"][i, ...].to(device=device, dtype=self.dtype)

self.up = torch.cat(self.up, dim=0)
self.gate = torch.cat(self.gate, dim=0)
self.down = torch.cat(self.down, dim=0)
return

def unload(self):
if self.gate is not None:
self.gate = None
self.up = None
self.down = None

def load_weights(self, override_key: str | None = None):
res = {}
if override_key is not None:
keys = override_key
else:
keys = [self.key]

gate = None
up = None
down = None

for key in keys:
if key + ".ffn_gate_exps.weight" in self.gguf_loader.tensor_info:
gate = self.gguf_loader.get_mmap_tensor(key + ".ffn_gate_exps.weight")
up = self.gguf_loader.get_mmap_tensor(key + ".ffn_up_exps.weight")
down = self.gguf_loader.get_mmap_tensor(key + ".ffn_down_exps.weight")
res = {"gate": gate, "up": up, "down": down}
return res

def forward(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:

org_device = hidden_states_cpu.device
@@ -582,7 +636,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

if isinstance(self.experts, KExpertsBase):
y = (
self.moe_on_cpuinfer(
self.moe_kexperts(
hidden_states_expert, selected_experts_expert, routing_weights_expert
)
.view(*orig_shape)
Expand All @@ -601,8 +655,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return y, router_logits

@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = self.experts(x, topk_ids, topk_weight)
return outs

@@ -672,7 +725,7 @@ def forward(self, hidden_states):
y_ = self.shared_experts(identity).squeeze(0)

if isinstance(self.experts, KExpertsBase):
y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
elif hidden_states.size(0) > 10:
# TODO may bugs here
y = (
Expand All @@ -692,8 +745,7 @@ def forward(self, hidden_states):
return y

@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = self.experts(x, topk_ids, topk_weight)
return outs

@@ -773,7 +825,7 @@ def forward(self, hidden_states):
y_ = self.shared_experts(identity).squeeze(0)

if isinstance(self.experts, KExpertsBase):
y = self.moe_on_cpuinfer(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
y = self.moe_kexperts(hidden_states, topk_idx, topk_weight).view(*orig_shape).to(device=hidden_states.device)
elif hidden_states.size(0) > 10:
# TODO may bugs here
y = (
Expand All @@ -793,8 +845,7 @@ def forward(self, hidden_states):
return y

@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = self.experts(x, topk_ids, topk_weight)
return outs

@@ -881,7 +932,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

if isinstance(self.experts, KExpertsBase):
y = (
self.moe_on_cpuinfer(
self.moe_kexperts(
hidden_states_expert, selected_experts_expert, routing_weights_expert
)
.view(*orig_shape)
Expand All @@ -900,8 +951,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return y, router_logits

@torch.no_grad()
def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = torch.empty_like(x)
def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
outs = self.experts(x, topk_ids, topk_weight)
return outs

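
The experts.py changes above replace "dequantize the whole stacked expert tensor, then slice per expert" with a streaming path: `load_weights` now returns raw mmapped GGUF buffers via `get_mmap_tensor`, and `load` dequantizes one expert at a time via `load_expert_tensor` before handing it to the per-expert projection. A minimal sketch of the peak-memory difference (element counts assume DeepSeek-V3's 2048x7168 expert matrices and 256 experts; the commented loop mirrors the pattern in the diff, not the exact ktransformers API):

```python
# Peak dequantized memory: whole stacked tensor vs. one expert at a time (fp32 output).
n_experts, intermediate, hidden = 256, 2048, 7168
bytes_fp32 = 4

whole_tensor_bytes = n_experts * intermediate * hidden * bytes_fp32  # ~15.0 GB live at once
one_expert_bytes = intermediate * hidden * bytes_fp32                # ~58.7 MB live at once
print(f"{whole_tensor_bytes / 1e9:.1f} GB vs {one_expert_bytes / 1e6:.1f} MB")

# Streaming pattern mirrored from the new load(): only one dequantized expert
# is materialized at a time before it is packed/quantized for the GPU kernel.
# (loader, raw_up and up_projs are hypothetical stand-ins for gguf_loader,
# the mmapped buffer, and the per-expert KLinearMarlin objects.)
# for i in range(n_experts):
#     w = loader.load_expert_tensor(key + ".ffn_up_exps.weight", raw_up, i,
#                                   intermediate * hidden, device="cuda")
#     up_projs[i].load(w, device="cuda")
```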