[Model] use FusedMoE layer in Jamba #6935

Merged
merged 4 commits on Jul 31, 2024

Changes from 2 commits
157 changes: 49 additions & 108 deletions vllm/model_executor/models/jamba.py
@@ -1,5 +1,5 @@
 # coding=utf-8
-"""Inference-only Jurassic model."""
+"""Inference-only Jamba model."""
 from dataclasses import dataclass
 from typing import Dict, Iterable, List, Optional, Tuple

@@ -15,10 +15,9 @@
 from vllm.attention.layer import Attention
 from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size,
-                              tensor_model_parallel_all_reduce)
+                              get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -282,108 +281,50 @@ def forward(self, x):


 class JambaMoE(nn.Module):
-    """A tensor-parallel MoE implementation for Mixtral that shards each expert
-    across all ranks.
-
-    Each expert's weights are sharded across all ranks and a fused MoE
-    kernel is used for the forward pass, and finally we reduce the outputs
-    across ranks.
-    """

-    def __init__(
-        self,
-        config: JambaConfig,
-        params_dtype: Optional[torch.dtype] = None,
-        tp_size: Optional[int] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ):
+    def __init__(self,
+                 config: JambaConfig,
+                 num_experts: Optional[int] = None,
+                 top_k: Optional[int] = None,
+                 params_dtype: Optional[torch.dtype] = None,
+                 tp_size: Optional[int] = None,
+                 quant_config: Optional[QuantizationConfig] = None):
         super().__init__()
-        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.num_experts
-        self.top_k = config.num_experts_per_tok
+        self.num_total_experts = num_experts or config.num_experts
+        self.top_k = top_k or config.num_experts_per_tok
         self.hidden_size = config.hidden_size
-        self.intermediate_size = config.intermediate_size // self.tp_size
-
-        if params_dtype is None:
-            params_dtype = torch.get_default_dtype()
-        self.params_dtype = params_dtype
-
-        self.router = ReplicatedLinear(self.hidden_size,
-                                       self.num_total_experts,
-                                       bias=False,
-                                       params_dtype=self.params_dtype)
-
-        self.ws = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                2 * self.intermediate_size,
-                self.hidden_size,
-                device="cuda",
-                dtype=self.params_dtype,
-            ))
-        self.w2s = nn.Parameter(
-            torch.empty(
-                self.num_total_experts,
-                self.hidden_size,
-                self.intermediate_size,
-                device="cuda",
-                dtype=self.params_dtype,
-            ))
+        self.intermediate_size = config.intermediate_size

-        set_weight_attrs(
-            self.ws,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-        set_weight_attrs(
-            self.w2s,
-            {
-                "weight_loader": self.weight_loader,
-            },
-        )
-
-    def weight_loader(
-        self,
-        param: nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        expert_id: int,
-    ):
-        tp_rank = get_tensor_model_parallel_rank()
-        param_data = param.data
-        shard_size = self.intermediate_size
-        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
-        if weight_name.endswith("gate_proj.weight"):
-            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("up_proj.weight"):
-            param_data[expert_id,
-                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
-        if weight_name.endswith("down_proj.weight"):
-            param_data[expert_id, :, :] = loaded_weight[:, shard]
+        if self.num_total_experts > 1:
+            self.router = ReplicatedLinear(self.hidden_size,
+                                           self.num_total_experts,
+                                           bias=False,
+                                           quant_config=None,
+                                           params_dtype=params_dtype)
+
+        self.experts = FusedMoE(self.num_total_experts,
+                                self.top_k,
+                                self.hidden_size,
+                                self.intermediate_size,
+                                tp_size=tp_size,
+                                params_dtype=params_dtype,
+                                reduce_results=True,
+                                renormalize=False,
+                                use_grouped_topk=False,
+                                quant_config=quant_config)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        num_tokens, hidden_size = hidden_states.shape
+        orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (batch * sequence_length, n_experts)
-        router_logits, _ = self.router(hidden_states)
-
-        final_hidden_states = fused_moe(
-            hidden_states,
-            self.ws,
-            self.w2s,
-            router_logits,
-            self.top_k,
-            renormalize=
-            False,  # Mixtral normalize the expert probs to 1. We don't!
-            inplace=True,
-        )
-
-        if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
-
-        return final_hidden_states.view(num_tokens, hidden_size)
+        if self.num_total_experts > 1:
+            router_logits, _ = self.router(hidden_states)
+        else:
+            router_logits = torch.ones((hidden_states.shape[0], 1),
+                                       device=hidden_states.device,
+                                       dtype=hidden_states.dtype)
+        hidden_states = self.experts(hidden_states, router_logits)
+        return hidden_states.view(orig_shape)


 class JambaMambaDecoderLayer(nn.Module):
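A quick sanity check on the single-expert fallback introduced above: with only one expert, feeding a constant logit to FusedMoE (which, as assumed here, derives routing weights via softmax over the expert dimension followed by top-k selection) routes every token to that expert with weight 1.0. A minimal PyTorch sketch, independent of vLLM:

import torch

# Stand-in for the router_logits fallback above: one "expert", four tokens.
router_logits = torch.ones((4, 1))

# Softmax over the expert dimension, then top-1 selection (renormalize=False
# keeps the softmax weights as-is); with a single expert the weight is 1.0.
routing_weights = torch.softmax(router_logits, dim=-1)
topk_weights, topk_ids = torch.topk(routing_weights, k=1, dim=-1)

print(topk_weights)  # all ones: every token goes to expert 0 at full weight
print(topk_ids)      # all zeros: the only expert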
@@ -917,15 +858,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]

-        expert_params_mapping = [
-            # (param_name, weight_name, expert_id)
-            (
-                "ws" if weight_name in ["gate_proj", "up_proj"] else "w2s",
-                f"experts.{expert_id}.{weight_name}.weight",
-                expert_id,
-            ) for expert_id in range(self.config.num_experts)
-            for weight_name in ["down_proj", "up_proj", "gate_proj"]
-        ]
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts)

         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
@@ -952,7 +891,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                for param_name, weight_name, expert_id in expert_params_mapping:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
@@ -961,6 +901,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader(param,
                                   loaded_weight,
                                   weight_name,
+                                  shard_id=shard_id,
                                   expert_id=expert_id)
                     break
                 else:
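As a rough illustration of the remapping driven by FusedMoE.make_expert_params_mapping: each mapping entry pairs a fused parameter name with a per-expert checkpoint prefix, and the loop above rewrites matching checkpoint keys before dispatching to the parameter's weight_loader. The concrete strings below are hypothetical; the real tuples come from make_expert_params_mapping and their exact format can differ between vLLM versions.

# Illustrative only: hypothetical mapping entry and checkpoint key.
param_name = "experts.w13_"           # assumed fused-parameter prefix (gate/up)
weight_name = "experts.3.gate_proj."  # assumed per-expert checkpoint prefix
expert_id, shard_id = 3, "w1"         # assumed shard label for gate_proj

# Hypothetical Jamba checkpoint key for expert 3's gate projection.
name = "model.layers.0.feed_forward.experts.3.gate_proj.weight"
if weight_name in name:
    name = name.replace(weight_name, param_name)
print(name)  # -> model.layers.0.feed_forward.experts.w13_weight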