diff --git a/cacheflow/models/llama.py b/cacheflow/models/llama.py
index 236feab4b4cb4..edd28a5bcbc6d 100644
--- a/cacheflow/models/llama.py
+++ b/cacheflow/models/llama.py
@@ -48,22 +48,21 @@ def __init__(
         hidden_act: str,
     ):
         super().__init__()
-        # TODO: Merge the gate and down linear layers.
-        self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size,
-                                              bias=False, gather_output=False,
-                                              perform_initialization=False)
+        self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
+                                                 bias=False, gather_output=False,
+                                                 perform_initialization=False)
         self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
                                            bias=False, input_is_parallel=True,
                                            perform_initialization=False)
-        self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size,
-                                            bias=False, gather_output=False,
-                                            perform_initialization=False)
         assert hidden_act == 'silu'
         self.act_fn = nn.SiLU()
 
     def forward(self, x):
-        gate, _ = self.gate_proj(x)
-        up, _ = self.up_proj(x)
+        gate_up, _ = self.gate_up_proj(x)
+        gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1))
+        gate, up = torch.split(gate_up, 1, dim=-2)
+        gate = gate.squeeze(dim=-2).contiguous()
+        up = up.squeeze(dim=-2).contiguous()
         x = self.act_fn(gate) * up
         x, _ = self.down_proj(x)
         return x
@@ -85,24 +84,9 @@ def __init__(
         self.head_dim = hidden_size // self.total_num_heads
         self.scaling = self.head_dim ** -0.5
 
-        # TODO: Merge the QKV linear layers.
-        self.q_proj = ColumnParallelLinear(
-            hidden_size,
-            self.total_num_heads * self.head_dim,
-            bias=False,
-            gather_output=False,
-            perform_initialization=False,
-        )
-        self.k_proj = ColumnParallelLinear(
-            hidden_size,
-            self.total_num_heads * self.head_dim,
-            bias=False,
-            gather_output=False,
-            perform_initialization=False,
-        )
-        self.v_proj = ColumnParallelLinear(
+        self.qkv_proj = ColumnParallelLinear(
             hidden_size,
-            self.total_num_heads * self.head_dim,
+            3 * self.total_num_heads * self.head_dim,
             bias=False,
             gather_output=False,
             perform_initialization=False,
@@ -124,9 +108,12 @@ def forward(
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
-        q, _ = self.q_proj(hidden_states)
-        k, _ = self.k_proj(hidden_states)
-        v, _ = self.v_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
+        q, k, v = torch.split(qkv, 1, dim=-2)
+        q = q.squeeze(dim=-2).contiguous()
+        k = k.squeeze(dim=-2).contiguous()
+        v = v.squeeze(dim=-2).contiguous()
         k_cache, v_cache = kv_cache
         attn_output = self.attn(
             positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
@@ -245,8 +232,7 @@ def forward(
         return next_tokens
 
     _column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
-                                "q_proj.weight", "k_proj.weight",
-                                "v_proj.weight", "gate_proj.weight",
+                                "qkv_proj.weight", "gate_proj.weight",
                                 "up_proj.weight"]
     _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
 
@@ -254,23 +240,42 @@ def load_weights(self, weights_path: str):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, param in state_dict.items():
-            loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path,
-                                                                  name)))
-            for p in self._column_parallel_weights:
-                if p in name:
-                    shard_size = param.shape[0]
-                    loaded_weight = loaded_weight[
+            if "qkv_proj" in name or "gate_up_proj" in name:
+                if "qkv_proj" in name:
+                    original_name = "qkv_proj"
+                    weight_names = ["q_proj", "k_proj", "v_proj"]
+                    shard_size = param.shape[0] // 3
+                else:
+                    original_name = "gate_up_proj"
+                    weight_names = ["gate_proj", "up_proj"]
+                    shard_size = param.shape[0] // 2
+                weights_to_concat = []
+                for weight_name in weight_names:
+                    weight = np.load(os.path.join(
+                        weights_path, name.replace(original_name, weight_name)))
+                    weights_to_concat.append(weight[
                         shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)]
-                    break
-            for p in self._row_parallel_weights:
-                if p in name:
-                    shard_size = param.shape[1]
-                    loaded_weight = loaded_weight[
-                        :,
-                        shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)]
-                    break
+                        :shard_size * (tensor_model_parallel_rank + 1)])
+                loaded_weight = torch.from_numpy(
+                    np.concatenate(weights_to_concat, axis=0))
+            else:
+                loaded_weight = torch.from_numpy(
+                    np.load(os.path.join(weights_path, name)))
+                for p in self._column_parallel_weights:
+                    if p in name:
+                        shard_size = param.shape[0]
+                        loaded_weight = loaded_weight[
+                            shard_size * tensor_model_parallel_rank
+                            :shard_size * (tensor_model_parallel_rank + 1)]
+                        break
+                for p in self._row_parallel_weights:
+                    if p in name:
+                        shard_size = param.shape[1]
+                        loaded_weight = loaded_weight[
+                            :,
+                            shard_size * tensor_model_parallel_rank
+                            :shard_size * (tensor_model_parallel_rank + 1)]
+                        break
             assert param.shape == loaded_weight.shape
             param.data.copy_(loaded_weight)
 
diff --git a/cacheflow/models/opt.py b/cacheflow/models/opt.py
index a74c6a97100b3..eed20ea41ffdd 100644
--- a/cacheflow/models/opt.py
+++ b/cacheflow/models/opt.py
@@ -53,16 +53,9 @@ def __init__(
         self.head_dim = embed_dim // total_num_heads
         self.scaling = self.head_dim ** -0.5
 
-        # TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
-        self.k_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
-                                           gather_output=False,
-                                           perform_initialization=False)
-        self.v_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
-                                           gather_output=False,
-                                           perform_initialization=False)
-        self.q_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
-                                           gather_output=False,
-                                           perform_initialization=False)
+        self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
+                                             gather_output=False,
+                                             perform_initialization=False)
         self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)
@@ -75,16 +68,18 @@ def forward(
         input_metadata: InputMetadata,
         cache_event: Optional[torch.cuda.Event],
     ) -> torch.Tensor:
-        q, _ = self.q_proj(hidden_states)
-        k, _ = self.k_proj(hidden_states)
-        v, _ = self.v_proj(hidden_states)
+        qkv, _ = self.qkv_proj(hidden_states)
+        qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
+        q, k, v = torch.split(qkv, 1, dim=-2)
+        q = q.squeeze(dim=-2).contiguous()
+        k = k.squeeze(dim=-2).contiguous()
+        v = v.squeeze(dim=-2).contiguous()
         key_cache, value_cache = kv_cache
         attn_output = self.attn(
             q, k, v, key_cache, value_cache, input_metadata, cache_event)
         output, _ = self.out_proj(attn_output)
         return output
 
-
 class OPTDecoderLayer(nn.Module):
 
     def __init__(self, config: OPTConfig):
@@ -262,11 +257,7 @@ def forward(
             self.lm_head_weight, hidden_states, input_metadata)
         return next_tokens
 
-    _column_parallel_weights = ["embed_tokens.weight",
-                                "q_proj.weight", "k_proj.weight",
-                                "v_proj.weight", "fc1.weight"]
-    _column_parallel_biases = ["q_proj.bias", "k_proj.bias",
-                               "v_proj.bias", "fc1.bias"]
+    _column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
     _row_parallel_weights = ["out_proj.weight", "fc2.weight"]
 
     def load_weights(self, weights_path: str):
@@ -275,24 +266,35 @@ def load_weights(self, weights_path: str):
         for name, param in state_dict.items():
             if "lm_head_weight" in name:
                 continue
-            loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path,
-                                                                  name)))
-            for p in (self._column_parallel_weights
-                      + self._column_parallel_biases):
-                if p in name:
-                    shard_size = param.shape[0]
-                    loaded_weight = loaded_weight[
-                        shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)]
-                    break
-            for p in self._row_parallel_weights:
-                if p in name:
-                    shard_size = param.shape[1]
-                    loaded_weight = loaded_weight[
-                        :,
+            if "qkv_proj" in name:
+                shard_size = param.shape[0] // 3
+                weights_to_concat = []
+                for weight_name in ["q_proj", "k_proj", "v_proj"]:
+                    weight = np.load(os.path.join(
+                        weights_path, name.replace("qkv_proj", weight_name)))
+                    weights_to_concat.append(weight[
                         shard_size * tensor_model_parallel_rank
-                        :shard_size * (tensor_model_parallel_rank + 1)]
-                    break
+                        :shard_size * (tensor_model_parallel_rank + 1)])
+                loaded_weight = torch.from_numpy(
+                    np.concatenate(weights_to_concat, axis=0))
+            else:
+                loaded_weight = torch.from_numpy(
+                    np.load(os.path.join(weights_path, name)))
+                for p in self._column_parallel_weights:
+                    if p in name:
+                        shard_size = param.shape[0]
+                        loaded_weight = loaded_weight[
+                            shard_size * tensor_model_parallel_rank
+                            :shard_size * (tensor_model_parallel_rank + 1)]
+                        break
+                for p in self._row_parallel_weights:
+                    if p in name:
+                        shard_size = param.shape[1]
+                        loaded_weight = loaded_weight[
+                            :,
+                            shard_size * tensor_model_parallel_rank
+                            :shard_size * (tensor_model_parallel_rank + 1)]
+                        break
             assert param.shape == loaded_weight.shape
             param.data.copy_(loaded_weight)
 
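Note (not part of the patch): the fusion above relies on two facts. The fused parameter is the row-wise concatenation of the per-projection weights, which is exactly what load_weights builds with np.concatenate(..., axis=0), and reshaping the fused output's last dimension into (3, -1) recovers the q/k/v chunks in that same order. The following is a minimal, self-contained sketch of that equivalence, assuming plain nn.Linear as a stand-in for ColumnParallelLinear and arbitrary toy sizes; it is illustrative only, not cacheflow code.

# Standalone sketch (assumed toy sizes; nn.Linear stands in for ColumnParallelLinear):
# checks that a fused QKV projection built by concatenating the separate weights
# reproduces the separate projections once the fused output is reshaped and split
# the way the patch does.
import torch
import torch.nn as nn

hidden_size = 32
num_heads, head_dim = 4, 8
proj_size = num_heads * head_dim

q_proj = nn.Linear(hidden_size, proj_size, bias=False)
k_proj = nn.Linear(hidden_size, proj_size, bias=False)
v_proj = nn.Linear(hidden_size, proj_size, bias=False)

# Fused weight = [Wq; Wk; Wv] stacked along the output dimension,
# mirroring np.concatenate(weights_to_concat, axis=0) in load_weights.
qkv_proj = nn.Linear(hidden_size, 3 * proj_size, bias=False)
with torch.no_grad():
    qkv_proj.weight.copy_(
        torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))

x = torch.randn(5, hidden_size)  # (num_tokens, hidden_size)

# Split pattern used in the patch: view the last dim as (3, proj_size),
# split along the new axis, then squeeze it away.
qkv = qkv_proj(x)
qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
q, k, v = torch.split(qkv, 1, dim=-2)
q, k, v = (t.squeeze(dim=-2).contiguous() for t in (q, k, v))

assert torch.allclose(q, q_proj(x), atol=1e-5)
assert torch.allclose(k, k_proj(x), atol=1e-5)
assert torch.allclose(v, v_proj(x), atol=1e-5)

Under tensor parallelism the same equivalence holds per rank shard, which is why load_weights slices each of q_proj/k_proj/v_proj (or gate_proj/up_proj) individually before concatenating, rather than slicing an already-concatenated tensor.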