
Commit

Updated code to use the top-two predictions and fixed the prefill code
gnpinkert committed Oct 16, 2024
1 parent b26c9e8 commit 9ce4444
Showing 2 changed files with 19 additions and 14 deletions.
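For context: Mixtral routes each token to its top-2 experts, and this commit aligns the expert-prefetch bookkeeping with that. Below is a minimal sketch of top-2 routing; only `routing_weights`, `selected_experts`, and `unique_values` mirror names in the diffs that follow, everything else is illustrative:

```python
import torch

torch.manual_seed(0)
router_logits = torch.randn(4, 8)  # (num_tokens, num_experts): illustrative values
routing_weights = torch.softmax(router_logits, dim=-1)
# Keep the two highest-scoring experts per token.
routing_weights, selected_experts = torch.topk(routing_weights, k=2, dim=-1)
# Renormalize the two kept weights to sum to one, as the diff below does.
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# The set of experts whose weights must actually be resident on the GPU.
unique_values = torch.unique(selected_experts)
```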
vllm/model_executor/layers/fused_moe/layer.py (2 changes: 1 addition & 1 deletion)
@@ -56,7 +56,7 @@ def __init__(self, num_experts: int, w1s_shape: tuple[int, int], w2s_shape: tupl
         self.qweight_w1s = torch.nn.Parameter(torch.zeros(shape_tuple + w1s_shape, device='cuda', dtype=torch.int32), requires_grad=False)
         self.qweight_w2s = torch.nn.Parameter(torch.zeros(shape_tuple + w2s_shape, device='cuda', dtype=torch.int32), requires_grad=False)
         self.qweight_w3s = torch.nn.Parameter(torch.zeros(shape_tuple + w3s_shape, device='cuda', dtype=torch.int32), requires_grad=False)
-        self.expert_ids: List[int] = [-1, -1, -1, -1]
+        self.expert_ids: List[int] = [-1, -1]
         self.load_predicted_experts_stream = torch.cuda.Stream()
         self.current_weight_class: WeightClass = WeightClass.UNKNOWN

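The change above shrinks the resident-expert tracker from four slots to two: with top-2 routing, at most two experts per layer need GPU-resident weights at a time, and -1 marks an empty slot. A hypothetical sketch of that two-slot bookkeeping (the class and method names are illustrative, not from the repository):

```python
class TwoSlotExpertCache:
    """Hypothetical two-slot tracker; -1 marks an empty slot, as in the diff."""

    def __init__(self) -> None:
        self.expert_ids: list[int] = [-1, -1]

    def slot_of(self, expert_id: int) -> int:
        """Return the slot holding expert_id, or -1 if it is not resident."""
        return self.expert_ids.index(expert_id) if expert_id in self.expert_ids else -1

    def place(self, expert_id: int) -> int:
        """Stage expert_id into an empty slot; callers free slots by writing -1."""
        slot = self.slot_of(-1)
        if slot == -1:
            raise RuntimeError("both slots occupied; free one first")
        self.expert_ids[slot] = expert_id
        return slot


cache = TwoSlotExpertCache()
cache.place(3)  # expert 3 lands in slot 0
assert cache.expert_ids == [3, -1] and cache.slot_of(3) == 0
```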
vllm/model_executor/models/mixtral_quant.py (31 changes: 18 additions & 13 deletions)
@@ -156,28 +156,34 @@ def forward(self, hidden_states: torch.Tensor, moe_gpu_buffer: MoeGpuBuffer,

         final_hidden_states = None
         unique_values = torch.unique(selected_experts)

         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         if is_prefill:
+            first_index = unique_values[0]
             prefill_stream = torch.cuda.Stream()
-            predicted_expert_list = [-1, -1, -1, -1]
+            predicted_expert_list = [-1, -1]
             for expert_idx in self.expert_indicies:
                 if expert_idx not in unique_values:
                     continue

                 expert_mlp = self.experts[expert_idx]
                 expert_mask = (selected_experts == expert_idx)
                 expert_weights = (routing_weights * expert_mask).sum(dim=-1,
                                                                      keepdim=True)

-                if expert_idx == 0:
-                    moe_gpu_buffer = self.load_experts([expert_idx], stream=moe_gpu_buffer.load_predicted_experts_stream, moe_gpu_buffer=moe_gpu_buffer)
-                if expert_idx < 8:
-                    with torch.cuda.stream(prefill_stream):
-                        predicted_expert_list[(expert_idx + 1) % len(predicted_expert_list)] = expert_idx
-                        predicted_expert_list[(expert_idx) % len(predicted_expert_list)] = -1
-                        moe_gpu_buffer = self.load_experts([expert_idx], stream=prefill_stream, moe_gpu_buffer=moe_gpu_buffer)
+                if expert_idx == first_index:
+                    predicted_expert_list[0] = expert_idx
+                    moe_gpu_buffer = self.load_experts(predicted_expert_list, stream=moe_gpu_buffer.load_predicted_experts_stream, moe_gpu_buffer=moe_gpu_buffer)
+
+                with torch.cuda.stream(prefill_stream):
+                    if expert_idx < 7:
+                        next_id = expert_idx + 1
+                        while next_id not in unique_values and next_id < 9:
+                            next_id += 1
+                        if next_id < 8:
+                            prev_index = predicted_expert_list.index(expert_idx)
+                            next_index = 1 - prev_index
+                            predicted_expert_list[next_index] = next_id
+                            predicted_expert_list[prev_index] = -1
+                            moe_gpu_buffer = self.load_experts(predicted_expert_list, stream=prefill_stream, moe_gpu_buffer=moe_gpu_buffer)

                 current_hidden_states = expert_mlp(hidden_states,
                                                    active_expert_idx=expert_idx,
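The rewritten prefill loop above is a ping-pong prefetch: while expert `expert_idx` computes, the next routed expert (the smallest routed id above it) is staged into the other of the two slots on `prefill_stream`, so weight transfer overlaps compute. A standalone sketch of the slot-swap logic under that reading; the helper names are hypothetical:

```python
def next_routed_expert(expert_idx: int, routed: set[int], num_experts: int = 8) -> int:
    """Smallest routed expert id greater than expert_idx, or -1 if none exists."""
    for candidate in range(expert_idx + 1, num_experts):
        if candidate in routed:
            return candidate
    return -1


def stage_next(predicted: list[int], expert_idx: int, routed: set[int]) -> None:
    """Swap the prefetch target into the slot NOT held by the active expert.

    Assumes expert_idx currently occupies one of the two slots, as the
    diff's predicted_expert_list.index(expert_idx) also assumes.
    """
    nxt = next_routed_expert(expert_idx, routed)
    if nxt == -1:
        return  # no later routed expert; nothing to prefetch
    prev_slot = predicted.index(expert_idx)
    predicted[1 - prev_slot] = nxt   # prefetch target goes into the other slot
    predicted[prev_slot] = -1        # the active expert's slot frees up next round


# Example: experts {1, 4, 6} are routed; while expert 1 runs, expert 4 is staged.
slots = [1, -1]
stage_next(slots, 1, {1, 4, 6})
assert slots == [-1, 4]
```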
@@ -406,7 +412,7 @@ def __init__(
         self.vocab_size = config.vocab_size

         self.predictor = Inference(model=MixtralModelConfig())
-        self.moe_gpu_buffers = MoeGpuBuffer(num_experts=4, w1s_shape=(256, 28672), w2s_shape=(896, 8192),
+        self.moe_gpu_buffers = MoeGpuBuffer(num_experts=2, w1s_shape=(256, 28672), w2s_shape=(896, 8192),
                                             w3s_shape=(256, 28672))
         self.norm_previous = torch.zeros(62, dtype=torch.float32).to('cuda')
         self.moe_events = DebugCudaEvent(topk=2)
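Shrinking `num_experts` from 4 to 2 halves the GPU staging buffers. Assuming the buffer prepends a `(num_experts,)` leading dimension to each weight shape (as `shape_tuple + w1s_shape` in layer.py suggests), the allocation looks roughly like this sketch; the shapes and dtype come from the diff, everything else is illustrative:

```python
import torch

# Allocated on CPU here so the sketch runs without a GPU;
# the real buffers live on 'cuda'.
num_experts = 2
w1s_shape, w2s_shape, w3s_shape = (256, 28672), (896, 8192), (256, 28672)

qweight_w1s = torch.zeros((num_experts, *w1s_shape), dtype=torch.int32)
qweight_w2s = torch.zeros((num_experts, *w2s_shape), dtype=torch.int32)
qweight_w3s = torch.zeros((num_experts, *w3s_shape), dtype=torch.int32)

# Each w1/w3 buffer is 2 * 256 * 28672 * 4 bytes = 56 MiB; halving the
# slot count from 4 to 2 halves this staging footprint.
assert qweight_w1s.shape == (2, 256, 28672)
```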
@@ -436,7 +442,6 @@ def forward(
         residual = None

         is_prefill = input_ids.size(0) > 1
-
         if is_prefill:
             for i in range(len(self.layers)):
                 layer = self.layers[i]
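This hunk's surrounding logic relies on the `is_prefill` heuristic visible above: prefill feeds the whole prompt at once while decode feeds one token per step, so a first dimension greater than one signals prefill. A tiny illustrative check:

```python
import torch

prompt_ids = torch.tensor([101, 2023, 2003, 1037])  # prefill: the whole prompt
decode_id = torch.tensor([42])                      # decode: one new token

assert prompt_ids.size(0) > 1       # treated as prefill
assert not (decode_id.size(0) > 1)  # treated as decode
```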
