From 71b0b537ecf80b2649543feddeef6a2fce39d23b Mon Sep 17 00:00:00 2001 From: Avshalom Date: Tue, 30 Jul 2024 11:54:22 +0300 Subject: [PATCH 1/4] use FusedMoE layer in Jamba --- vllm/model_executor/models/jamba.py | 156 +++++++++------------------- 1 file changed, 48 insertions(+), 108 deletions(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 3444578227259..8cf2df234c82a 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -1,5 +1,5 @@ # coding=utf-8 -"""Inference-only Jurassic model.""" +"""Inference-only Jamba model.""" from dataclasses import dataclass from typing import Dict, Iterable, List, Optional, Tuple @@ -15,10 +15,9 @@ from vllm.attention.layer import Attention from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce) + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, @@ -282,108 +281,50 @@ def forward(self, x): class JambaMoE(nn.Module): - """A tensor-parallel MoE implementation for Mixtral that shards each expert - across all ranks. - Each expert's weights are sharded across all ranks and a fused MoE - kernel is used for the forward pass, and finally we reduce the outputs - across ranks. - """ - - def __init__( - self, - config: JambaConfig, - params_dtype: Optional[torch.dtype] = None, - tp_size: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - ): + def __init__(self, + config: JambaConfig, + num_experts: Optional[int] = None, + top_k: Optional[int] = None, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__() - self.tp_size = tp_size or get_tensor_model_parallel_world_size() - self.num_total_experts = config.num_experts - self.top_k = config.num_experts_per_tok + self.num_total_experts = num_experts or config.num_experts + self.top_k = top_k or config.num_experts_per_tok self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size // self.tp_size - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype - - self.router = ReplicatedLinear(self.hidden_size, - self.num_total_experts, - bias=False, - params_dtype=self.params_dtype) - - self.ws = nn.Parameter( - torch.empty( - self.num_total_experts, - 2 * self.intermediate_size, - self.hidden_size, - device="cuda", - dtype=self.params_dtype, - )) - self.w2s = nn.Parameter( - torch.empty( - self.num_total_experts, - self.hidden_size, - self.intermediate_size, - device="cuda", - dtype=self.params_dtype, - )) + self.intermediate_size = config.intermediate_size - set_weight_attrs( - self.ws, - { - "weight_loader": self.weight_loader, - }, - ) - set_weight_attrs( - self.w2s, - { - "weight_loader": self.weight_loader, - }, - ) - - def weight_loader( - self, - param: nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - expert_id: int, - ): - tp_rank = get_tensor_model_parallel_rank() - param_data = param.data - shard_size = self.intermediate_size - shard = slice(tp_rank * 
shard_size, (tp_rank + 1) * shard_size) - if weight_name.endswith("gate_proj.weight"): - param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("up_proj.weight"): - param_data[expert_id, - shard_size:2 * shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("down_proj.weight"): - param_data[expert_id, :, :] = loaded_weight[:, shard] + if self.num_total_experts > 1: + self.router = ReplicatedLinear(self.hidden_size, + self.num_total_experts, + bias=False, + quant_config=None, + params_dtype=params_dtype) + + self.experts = FusedMoE(self.num_total_experts, + self.top_k, + self.hidden_size, + self.intermediate_size, + tp_size=tp_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=False, + use_grouped_topk=False, + quant_config=quant_config) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_size = hidden_states.shape + orig_shape = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (batch * sequence_length, n_experts) - router_logits, _ = self.router(hidden_states) - - final_hidden_states = fused_moe( - hidden_states, - self.ws, - self.w2s, - router_logits, - self.top_k, - renormalize= - False, # Mixtral normalize the expert probs to 1. We don't! - inplace=True, - ) - - if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) - - return final_hidden_states.view(num_tokens, hidden_size) + if self.num_total_experts > 1: + router_logits, _ = self.router(hidden_states) + else: + router_logits = torch.ones((hidden_states.shape[0], 1), + device=hidden_states.device, + dtype=hidden_states.dtype) + hidden_states = self.experts(hidden_states, router_logits) + return hidden_states.view(orig_shape) class JambaMambaDecoderLayer(nn.Module): @@ -917,15 +858,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("gate_up_proj", "up_proj", 1), ] - expert_params_mapping = [ - # (param_name, weight_name, expert_id) - ( - "ws" if weight_name in ["gate_proj", "up_proj"] else "w2s", - f"experts.{expert_id}.{weight_name}.weight", - expert_id, - ) for expert_id in range(self.config.num_experts) - for weight_name in ["down_proj", "up_proj", "gate_proj"] - ] + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: @@ -952,7 +891,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, shard_id) break else: - for param_name, weight_name, expert_id in expert_params_mapping: + for param_name, weight_name, expert_id, shard_id in expert_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -961,6 +900,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, weight_name, + shard_id=shard_id, expert_id=expert_id) break else: From 6ffe62bb8b296db77fa154e215dda49595e9f4ed Mon Sep 17 00:00:00 2001 From: Avshalom Date: Tue, 30 Jul 2024 12:17:40 +0300 Subject: [PATCH 2/4] fix ruff --- vllm/model_executor/models/jamba.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/jamba.py 
b/vllm/model_executor/models/jamba.py index 8cf2df234c82a..cf407c86acd7d 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -891,7 +891,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, shard_id) break else: - for param_name, weight_name, expert_id, shard_id in expert_params_mapping: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue name = name.replace(weight_name, param_name) From 061ce18232390266187a01b466036af1044e4c67 Mon Sep 17 00:00:00 2001 From: Avshalom Date: Wed, 31 Jul 2024 10:07:49 +0300 Subject: [PATCH 3/4] initiating ci --- vllm/model_executor/models/jamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index cf407c86acd7d..2654a03d11a6e 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -3,10 +3,10 @@ from dataclasses import dataclass from typing import Dict, Iterable, List, Optional, Tuple -import torch from causal_conv1d import causal_conv1d_fn, causal_conv1d_update from mamba_ssm.ops.selective_scan_interface import selective_scan_fn from mamba_ssm.ops.triton.selective_state_update import selective_state_update +import torch from torch import nn from torch.nn.parameter import Parameter from transformers import JambaConfig From 78a2b857f5ed21cabda10ece69600ee115ee1c41 Mon Sep 17 00:00:00 2001 From: Avshalom Date: Wed, 31 Jul 2024 10:17:48 +0300 Subject: [PATCH 4/4] resorting imports --- vllm/model_executor/models/jamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 2654a03d11a6e..cf407c86acd7d 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -3,10 +3,10 @@ from dataclasses import dataclass from typing import Dict, Iterable, List, Optional, Tuple +import torch from causal_conv1d import causal_conv1d_fn, causal_conv1d_update from mamba_ssm.ops.selective_scan_interface import selective_scan_fn from mamba_ssm.ops.triton.selective_state_update import selective_state_update -import torch from torch import nn from torch.nn.parameter import Parameter from transformers import JambaConfig
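
Note for reviewers: below is a minimal, self-contained sketch of the routing semantics this refactor preserves. It is illustrative only (plain PyTorch, not the fused vLLM kernel or the author's implementation): softmax over the router logits, top-k selection with renormalize=False so the selected weights are not rescaled to sum to 1 (matching the removed "Mixtral normalize the expert probs to 1. We don't!" comment), and the single-expert fallback that feeds all-ones logits so the layer degenerates to a plain MLP. The helper name and the toy shapes are made up for illustration.

    # Illustrative reference only -- not the vLLM FusedMoE kernel.
    import torch


    def moe_reference(hidden_states: torch.Tensor,
                      router_logits: torch.Tensor,
                      experts: list,  # one callable (expert MLP) per expert
                      top_k: int) -> torch.Tensor:
        # Probabilities over all experts; with renormalize=False the top-k
        # weights are used as-is and need not sum to 1 per token.
        probs = torch.softmax(router_logits, dim=-1)
        topk_weights, topk_ids = probs.topk(top_k, dim=-1)
        out = torch.zeros_like(hidden_states)
        for token in range(hidden_states.shape[0]):
            for weight, expert_id in zip(topk_weights[token], topk_ids[token]):
                out[token] += weight * experts[expert_id](hidden_states[token])
        return out


    # Single-expert case from the patch: all-ones logits -> softmax gives 1.0,
    # so the MoE reduces to applying the single expert MLP to every token.
    hidden = torch.randn(4, 8)
    experts = [torch.nn.Linear(8, 8)]
    router_logits = torch.ones(hidden.shape[0], 1)
    out = moe_reference(hidden, router_logits, experts, top_k=1)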