From 28df307bc8ee7f8933857eb894451642848ed5e2 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 30 Mar 2023 16:49:08 +0000 Subject: [PATCH 1/3] Merge QKV for OPT --- cacheflow/models/attention.py | 2 +- cacheflow/models/opt.py | 49 ++++++++++++++++++++--------------- 2 files changed, 29 insertions(+), 22 deletions(-) diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py index 7f24670b7eaa3..60226341bc693 100644 --- a/cacheflow/models/attention.py +++ b/cacheflow/models/attention.py @@ -90,7 +90,7 @@ def forward( cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: # [num_tokens, num_heads * head_size] # Pre-allocate the output tensor. - output = torch.empty_like(query) + output = torch.zeros_like(query) # Prune out paddings if any. query = query[:input_metadata.num_valid_tokens] diff --git a/cacheflow/models/opt.py b/cacheflow/models/opt.py index 3a7e6a1103855..be5709a38ee4b 100644 --- a/cacheflow/models/opt.py +++ b/cacheflow/models/opt.py @@ -53,16 +53,9 @@ def __init__( self.head_dim = embed_dim // total_num_heads self.scaling = self.head_dim**-0.5 - # TODO(woosuk): Fuse the three linear layers into one QKV linear layer. - self.k_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias, - gather_output=False, - perform_initialization=False) - self.v_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias, - gather_output=False, - perform_initialization=False) - self.q_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias, - gather_output=False, - perform_initialization=False) + self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias, + gather_output=False, + perform_initialization=False) self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias, input_is_parallel=True, perform_initialization=False) @@ -76,16 +69,18 @@ def forward( input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: - q, _ = self.q_proj(hidden_states) - k, _ = self.k_proj(hidden_states) - v, _ = self.v_proj(hidden_states) + qkv, _ = self.qkv_proj(hidden_states) + qkv = qkv.reshape(qkv.shape[:-1] + (-1, 3)) + q, k, v = torch.split(qkv, 1, dim=-1) + q = q.squeeze(dim=-1).contiguous() + k = k.squeeze(dim=-1).contiguous() + v = v.squeeze(dim=-1).contiguous() key_cache, value_cache = kv_cache attn_output = self.attn( q, k, v, key_cache, value_cache, input_metadata, cache_event) output, _ = self.out_proj(attn_output) return output - class OPTDecoderLayer(nn.Module): def __init__(self, config: OPTConfig): @@ -263,11 +258,9 @@ def forward( self.lm_head_weight, hidden_states, input_metadata) return next_tokens - _column_parallel_weights = ["embed_tokens.weight", - "q_proj.weight", "k_proj.weight", - "v_proj.weight", "fc1.weight"] - _column_parallel_biases = ["q_proj.bias", "k_proj.bias", - "v_proj.bias", "fc1.bias"] + _column_parallel_weights = ["embed_tokens.weight", "qkv_proj.weight", + "fc1.weight"] + _column_parallel_biases = ["qkv_proj.bias", "fc1.bias"] _row_parallel_weights = ["out_proj.weight", "fc2.weight"] def load_weights(self, weights_path: str): @@ -276,8 +269,22 @@ def load_weights(self, weights_path: str): for name, param in state_dict.items(): if "lm_head_weight" in name: continue - loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path, - name))) + if "qkv_proj.weight" in name: + q_weight = np.load(os.path.join(weights_path, name.replace("qkv_proj", "q_proj"))) + k_weight = np.load(os.path.join(weights_path, name.replace("qkv_proj", "k_proj"))) + v_weight = np.load(os.path.join(weights_path, 
name.replace("qkv_proj", "v_proj"))) + loaded_weight = np.stack([q_weight, k_weight, v_weight]).transpose(1, 0, 2) + loaded_weight = torch.from_numpy(loaded_weight.reshape(-1, loaded_weight.shape[-1])) + elif "qkv_proj.bias" in name: + q_bias = np.load(os.path.join(weights_path, name.replace("qkv_proj", "q_proj"))) + k_bias = np.load(os.path.join(weights_path, name.replace("qkv_proj", "k_proj"))) + v_bias = np.load(os.path.join(weights_path, name.replace("qkv_proj", "v_proj"))) + loaded_weight = np.stack([q_bias, k_bias, v_bias]).transpose(1, 0) + loaded_weight = torch.from_numpy(loaded_weight.reshape(-1)) + else: + loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path, + name))) + for p in (self._column_parallel_weights + self._column_parallel_biases): if p in name: From 2e417f5d6fea6a63f0e098af8cd71125a9988720 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 30 Mar 2023 17:03:58 +0000 Subject: [PATCH 2/3] merge qkv for llama --- cacheflow/models/attention.py | 2 +- cacheflow/models/llama.py | 63 +++++++++++++++++------------------ cacheflow/models/opt.py | 3 +- 3 files changed, 32 insertions(+), 36 deletions(-) diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py index 60226341bc693..7f24670b7eaa3 100644 --- a/cacheflow/models/attention.py +++ b/cacheflow/models/attention.py @@ -90,7 +90,7 @@ def forward( cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: # [num_tokens, num_heads * head_size] # Pre-allocate the output tensor. - output = torch.zeros_like(query) + output = torch.empty_like(query) # Prune out paddings if any. query = query[:input_metadata.num_valid_tokens] diff --git a/cacheflow/models/llama.py b/cacheflow/models/llama.py index 4ddbc698eb789..2df27fcd3e7b8 100644 --- a/cacheflow/models/llama.py +++ b/cacheflow/models/llama.py @@ -90,22 +90,21 @@ def __init__( hidden_act: str, ): super().__init__() - # TODO: Merge the gate and down linear layers. - self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size, - bias=False, gather_output=False, - perform_initialization=False) + self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size, + bias=False, gather_output=False, + perform_initialization=False) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, input_is_parallel=True, perform_initialization=False) - self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size, - bias=False, gather_output=False, - perform_initialization=False) assert hidden_act == 'silu' self.act_fn = nn.SiLU() def forward(self, x): - gate, _ = self.gate_proj(x) - up, _ = self.up_proj(x) + gate_up, _ = self.gate_up_proj(x) + gate_up = gate_up.reshape(gate_up.shape[:-1] + (-1, 2)) + gate, up = torch.split(gate_up, 1, dim=-1) + gate = gate.squeeze(dim=-1).contiguous() + up = up.squeeze(dim=-1).contiguous() x = self.act_fn(gate) * up x, _ = self.down_proj(x) return x @@ -127,24 +126,9 @@ def __init__( self.head_dim = hidden_size // self.total_num_heads self.scaling = self.head_dim ** -0.5 - # TODO: Merge the QKV linear layers. 
- self.q_proj = ColumnParallelLinear( - hidden_size, - self.total_num_heads * self.head_dim, - bias=False, - gather_output=False, - perform_initialization=False, - ) - self.k_proj = ColumnParallelLinear( + self.qkv_proj = ColumnParallelLinear( hidden_size, - self.total_num_heads * self.head_dim, - bias=False, - gather_output=False, - perform_initialization=False, - ) - self.v_proj = ColumnParallelLinear( - hidden_size, - self.total_num_heads * self.head_dim, + 3 * self.total_num_heads * self.head_dim, bias=False, gather_output=False, perform_initialization=False, @@ -168,9 +152,12 @@ def forward( input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: - q, _ = self.q_proj(hidden_states) - k, _ = self.k_proj(hidden_states) - v, _ = self.v_proj(hidden_states) + qkv, _ = self.qkv_proj(hidden_states) + qkv = qkv.reshape(qkv.shape[:-1] + (-1, 3)) + q, k, v = torch.split(qkv, 1, dim=-1) + q = q.squeeze(dim=-1).contiguous() + k = k.squeeze(dim=-1).contiguous() + v = v.squeeze(dim=-1).contiguous() # Apply rotrary embedding. # TODO: Optimize. @@ -299,8 +286,7 @@ def forward( return next_tokens _column_parallel_weights = ["embed_tokens.weight", "lm_head.weight", - "q_proj.weight", "k_proj.weight", - "v_proj.weight", "gate_proj.weight", + "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"] _row_parallel_weights = ["o_proj.weight", "down_proj.weight"] @@ -308,8 +294,19 @@ def load_weights(self, weights_path: str): tensor_model_parallel_rank = get_tensor_model_parallel_rank() state_dict = self.state_dict() for name, param in state_dict.items(): - loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path, - name))) + if "qkv_proj.weight" in name: + q_weight = np.load(os.path.join(weights_path, name.replace("qkv_proj", "q_proj"))) + k_weight = np.load(os.path.join(weights_path, name.replace("qkv_proj", "k_proj"))) + v_weight = np.load(os.path.join(weights_path, name.replace("qkv_proj", "v_proj"))) + loaded_weight = np.stack([q_weight, k_weight, v_weight]).transpose(1, 0, 2) + loaded_weight = torch.from_numpy(loaded_weight.reshape(-1, loaded_weight.shape[-1])) + elif "gate_up_proj.weight" in name: + gate_weight = np.load(os.path.join(weights_path, name.replace("gate_up_proj", "gate_proj"))) + up_weight = np.load(os.path.join(weights_path, name.replace("gate_up_proj", "up_proj"))) + loaded_weight = np.stack([gate_weight, up_weight]).transpose(1, 0, 2) + loaded_weight = torch.from_numpy(loaded_weight.reshape(-1, loaded_weight.shape[-1])) + else: + loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path, name))) for p in self._column_parallel_weights: if p in name: shard_size = param.shape[0] diff --git a/cacheflow/models/opt.py b/cacheflow/models/opt.py index be5709a38ee4b..4759c215427fc 100644 --- a/cacheflow/models/opt.py +++ b/cacheflow/models/opt.py @@ -282,8 +282,7 @@ def load_weights(self, weights_path: str): loaded_weight = np.stack([q_bias, k_bias, v_bias]).transpose(1, 0) loaded_weight = torch.from_numpy(loaded_weight.reshape(-1)) else: - loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path, - name))) + loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path, name))) for p in (self._column_parallel_weights + self._column_parallel_biases): From 06f23ffac90cc9a5b65eda23855db87f70f7787a Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 31 Mar 2023 07:24:10 +0000 Subject: [PATCH 3/3] fix the code according to woosuk's comment --- cacheflow/models/llama.py | 80 +++++++++++++++++++++------------------ 
cacheflow/models/opt.py | 72 +++++++++++++++++------------------ 2 files changed, 78 insertions(+), 74 deletions(-) diff --git a/cacheflow/models/llama.py b/cacheflow/models/llama.py index 2df27fcd3e7b8..b28fc94b37164 100644 --- a/cacheflow/models/llama.py +++ b/cacheflow/models/llama.py @@ -101,10 +101,10 @@ def __init__( def forward(self, x): gate_up, _ = self.gate_up_proj(x) - gate_up = gate_up.reshape(gate_up.shape[:-1] + (-1, 2)) - gate, up = torch.split(gate_up, 1, dim=-1) - gate = gate.squeeze(dim=-1).contiguous() - up = up.squeeze(dim=-1).contiguous() + gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1)) + gate, up = torch.split(gate_up, 1, dim=-2) + gate = gate.squeeze(dim=-2).contiguous() + up = up.squeeze(dim=-2).contiguous() x = self.act_fn(gate) * up x, _ = self.down_proj(x) return x @@ -153,11 +153,11 @@ def forward( cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) - qkv = qkv.reshape(qkv.shape[:-1] + (-1, 3)) - q, k, v = torch.split(qkv, 1, dim=-1) - q = q.squeeze(dim=-1).contiguous() - k = k.squeeze(dim=-1).contiguous() - v = v.squeeze(dim=-1).contiguous() + qkv = qkv.reshape(qkv.shape[:-1] + (3, -1)) + q, k, v = torch.split(qkv, 1, dim=-2) + q = q.squeeze(dim=-2).contiguous() + k = k.squeeze(dim=-2).contiguous() + v = v.squeeze(dim=-2).contiguous() # Apply rotrary embedding. # TODO: Optimize. @@ -294,34 +294,42 @@ def load_weights(self, weights_path: str): tensor_model_parallel_rank = get_tensor_model_parallel_rank() state_dict = self.state_dict() for name, param in state_dict.items(): - if "qkv_proj.weight" in name: - q_weight = np.load(os.path.join(weights_path, name.replace("qkv_proj", "q_proj"))) - k_weight = np.load(os.path.join(weights_path, name.replace("qkv_proj", "k_proj"))) - v_weight = np.load(os.path.join(weights_path, name.replace("qkv_proj", "v_proj"))) - loaded_weight = np.stack([q_weight, k_weight, v_weight]).transpose(1, 0, 2) - loaded_weight = torch.from_numpy(loaded_weight.reshape(-1, loaded_weight.shape[-1])) - elif "gate_up_proj.weight" in name: - gate_weight = np.load(os.path.join(weights_path, name.replace("gate_up_proj", "gate_proj"))) - up_weight = np.load(os.path.join(weights_path, name.replace("gate_up_proj", "up_proj"))) - loaded_weight = np.stack([gate_weight, up_weight]).transpose(1, 0, 2) - loaded_weight = torch.from_numpy(loaded_weight.reshape(-1, loaded_weight.shape[-1])) - else: - loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path, name))) - for p in self._column_parallel_weights: - if p in name: - shard_size = param.shape[0] - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank - :shard_size * (tensor_model_parallel_rank + 1)] - break - for p in self._row_parallel_weights: - if p in name: - shard_size = param.shape[1] - loaded_weight = loaded_weight[ - :, + if "qkv_proj" in name or "gate_up_proj" in name: + if "qkv_proj" in name: + original_name = "qkv_proj" + weight_names = ["q_proj", "k_proj", "v_proj"] + shard_size = param.shape[0] // 3 + else: + original_name = "gate_up_proj" + weight_names = ["gate_proj", "up_proj"] + shard_size = param.shape[0] // 2 + weights_to_concat = [] + for weight_name in weight_names: + weight = np.load(os.path.join( + weights_path, name.replace(original_name, weight_name))) + weights_to_concat.append(weight[ shard_size * tensor_model_parallel_rank - :shard_size * (tensor_model_parallel_rank + 1)] - break + :shard_size * (tensor_model_parallel_rank + 1)]) + loaded_weight = torch.from_numpy( + 
np.concatenate(weights_to_concat, axis=0)) + else: + loaded_weight = torch.from_numpy( + np.load(os.path.join(weights_path, name))) + for p in self._column_parallel_weights: + if p in name: + shard_size = param.shape[0] + loaded_weight = loaded_weight[ + shard_size * tensor_model_parallel_rank + :shard_size * (tensor_model_parallel_rank + 1)] + break + for p in self._row_parallel_weights: + if p in name: + shard_size = param.shape[1] + loaded_weight = loaded_weight[ + :, + shard_size * tensor_model_parallel_rank + :shard_size * (tensor_model_parallel_rank + 1)] + break assert param.shape == loaded_weight.shape param.data.copy_(loaded_weight) diff --git a/cacheflow/models/opt.py b/cacheflow/models/opt.py index 4759c215427fc..3a5e968e3ddb6 100644 --- a/cacheflow/models/opt.py +++ b/cacheflow/models/opt.py @@ -70,11 +70,11 @@ def forward( cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) - qkv = qkv.reshape(qkv.shape[:-1] + (-1, 3)) - q, k, v = torch.split(qkv, 1, dim=-1) - q = q.squeeze(dim=-1).contiguous() - k = k.squeeze(dim=-1).contiguous() - v = v.squeeze(dim=-1).contiguous() + qkv = qkv.reshape(qkv.shape[:-1] + (3, -1)) + q, k, v = torch.split(qkv, 1, dim=-2) + q = q.squeeze(dim=-2).contiguous() + k = k.squeeze(dim=-2).contiguous() + v = v.squeeze(dim=-2).contiguous() key_cache, value_cache = kv_cache attn_output = self.attn( q, k, v, key_cache, value_cache, input_metadata, cache_event) @@ -258,9 +258,7 @@ def forward( self.lm_head_weight, hidden_states, input_metadata) return next_tokens - _column_parallel_weights = ["embed_tokens.weight", "qkv_proj.weight", - "fc1.weight"] - _column_parallel_biases = ["qkv_proj.bias", "fc1.bias"] + _column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"] _row_parallel_weights = ["out_proj.weight", "fc2.weight"] def load_weights(self, weights_path: str): @@ -269,37 +267,35 @@ def load_weights(self, weights_path: str): for name, param in state_dict.items(): if "lm_head_weight" in name: continue - if "qkv_proj.weight" in name: - q_weight = np.load(os.path.join(weights_path, name.replace("qkv_proj", "q_proj"))) - k_weight = np.load(os.path.join(weights_path, name.replace("qkv_proj", "k_proj"))) - v_weight = np.load(os.path.join(weights_path, name.replace("qkv_proj", "v_proj"))) - loaded_weight = np.stack([q_weight, k_weight, v_weight]).transpose(1, 0, 2) - loaded_weight = torch.from_numpy(loaded_weight.reshape(-1, loaded_weight.shape[-1])) - elif "qkv_proj.bias" in name: - q_bias = np.load(os.path.join(weights_path, name.replace("qkv_proj", "q_proj"))) - k_bias = np.load(os.path.join(weights_path, name.replace("qkv_proj", "k_proj"))) - v_bias = np.load(os.path.join(weights_path, name.replace("qkv_proj", "v_proj"))) - loaded_weight = np.stack([q_bias, k_bias, v_bias]).transpose(1, 0) - loaded_weight = torch.from_numpy(loaded_weight.reshape(-1)) - else: - loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path, name))) - - for p in (self._column_parallel_weights - + self._column_parallel_biases): - if p in name: - shard_size = param.shape[0] - loaded_weight = loaded_weight[ + if "qkv_proj" in name: + shard_size = param.shape[0] // 3 + weights_to_concat = [] + for weight_name in ["q_proj", "k_proj", "v_proj"]: + weight = np.load(os.path.join( + weights_path, name.replace("qkv_proj", weight_name))) + weights_to_concat.append(weight[ shard_size * tensor_model_parallel_rank - :shard_size * (tensor_model_parallel_rank + 1)] - break - for p in self._row_parallel_weights: - if p in 
name: - shard_size = param.shape[1] - loaded_weight = loaded_weight[ - :, - shard_size * tensor_model_parallel_rank - :shard_size * (tensor_model_parallel_rank + 1)] - break + :shard_size * (tensor_model_parallel_rank + 1)]) + loaded_weight = torch.from_numpy( + np.concatenate(weights_to_concat, axis=0)) + else: + loaded_weight = torch.from_numpy( + np.load(os.path.join(weights_path, name))) + for p in self._column_parallel_weights: + if p in name: + shard_size = param.shape[0] + loaded_weight = loaded_weight[ + shard_size * tensor_model_parallel_rank + :shard_size * (tensor_model_parallel_rank + 1)] + break + for p in self._row_parallel_weights: + if p in name: + shard_size = param.shape[1] + loaded_weight = loaded_weight[ + :, + shard_size * tensor_model_parallel_rank + :shard_size * (tensor_model_parallel_rank + 1)] + break assert param.shape == loaded_weight.shape param.data.copy_(loaded_weight)
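
For reference, below is a minimal, self-contained sketch of the fusion these three patches converge on. It is not part of the patches: plain torch.nn.Linear stands in for cacheflow's ColumnParallelLinear (so gather_output, perform_initialization, and the real tensor-parallel plumbing are omitted), and the toy sizes and the names embed_dim, intermediate, tp_size, rank, and shard are illustrative assumptions rather than identifiers taken from the diffs.

import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, intermediate = 8, 16
x = torch.randn(4, embed_dim)  # [num_tokens, embed_dim]

# --- Fused QKV ----------------------------------------------------------
q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
v_proj = nn.Linear(embed_dim, embed_dim, bias=False)

# The fused weight is the row-wise concatenation [q; k; v], matching the
# blocked layout that patch 3 settles on.
qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
with torch.no_grad():
    qkv_proj.weight.copy_(
        torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))

# One large matmul, then the (3, -1) reshape + split used in patch 3.
qkv = qkv_proj(x)
qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
q, k, v = (t.squeeze(dim=-2).contiguous()
           for t in torch.split(qkv, 1, dim=-2))

# The fused path reproduces the three separate projections.
assert torch.allclose(q, q_proj(x), atol=1e-6)
assert torch.allclose(k, k_proj(x), atol=1e-6)
assert torch.allclose(v, v_proj(x), atol=1e-6)

# --- Fused gate/up (the LLaMA MLP case) ----------------------------------
gate_proj = nn.Linear(embed_dim, intermediate, bias=False)
up_proj = nn.Linear(embed_dim, intermediate, bias=False)
gate_up_proj = nn.Linear(embed_dim, 2 * intermediate, bias=False)
with torch.no_grad():
    gate_up_proj.weight.copy_(
        torch.cat([gate_proj.weight, up_proj.weight], dim=0))

gate_up = gate_up_proj(x)
gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1))
gate, up = (t.squeeze(dim=-2) for t in torch.split(gate_up, 1, dim=-2))

assert torch.allclose(nn.functional.silu(gate) * up,
                      nn.functional.silu(gate_proj(x)) * up_proj(x),
                      atol=1e-6)

# --- Per-rank loading, as in patch 3's load_weights ----------------------
# Each rank slices q, k, and v individually and concatenates the slices,
# rather than slicing the fused weight, so the rank-local shard stays
# blocked as [q_shard; k_shard; v_shard].
tp_size, rank = 2, 0           # illustrative values
shard = embed_dim // tp_size   # == param.shape[0] // 3 in the patch
rank_weight = torch.cat(
    [w[shard * rank:shard * (rank + 1)]
     for w in (q_proj.weight, k_proj.weight, v_proj.weight)], dim=0)
assert rank_weight.shape == (3 * embed_dim // tp_size, embed_dim)

The design point of patch 3's (3, -1) and (2, -1) blocked layouts, versus the interleaved (-1, 3) layout of the first two patches, appears to be exactly this loading path: with the blocked layout, each rank's fused shard is a plain concatenation of the per-projection shards, so load_weights can slice each original weight at the rank boundary and concatenate, as the last section above mimics.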