Skip to content

Commit

Permalink
Merge QKV into one linear layer (vllm-project#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuohan123 authored Apr 2, 2023
1 parent f4354de commit c8f2711
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 83 deletions.
99 changes: 52 additions & 47 deletions cacheflow/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,21 @@ def __init__(
hidden_act: str,
):
super().__init__()
# TODO: Merge the gate and down linear layers.
self.gate_proj = ColumnParallelLinear(hidden_size, intermediate_size,
bias=False, gather_output=False,
perform_initialization=False)
self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size,
bias=False, gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=False, input_is_parallel=True,
perform_initialization=False)
self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size,
bias=False, gather_output=False,
perform_initialization=False)
assert hidden_act == 'silu'
self.act_fn = nn.SiLU()

def forward(self, x):
gate, _ = self.gate_proj(x)
up, _ = self.up_proj(x)
gate_up, _ = self.gate_up_proj(x)
gate_up = gate_up.reshape(gate_up.shape[:-1] + (2, -1))
gate, up = torch.split(gate_up, 1, dim=-2)
gate = gate.squeeze(dim=-2).contiguous()
up = up.squeeze(dim=-2).contiguous()
x = self.act_fn(gate) * up
x, _ = self.down_proj(x)
return x
Expand All @@ -70,24 +69,9 @@ def __init__(
self.head_dim = hidden_size // self.total_num_heads
self.scaling = self.head_dim ** -0.5

# TODO: Merge the QKV linear layers.
self.q_proj = ColumnParallelLinear(
hidden_size,
self.total_num_heads * self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.k_proj = ColumnParallelLinear(
hidden_size,
self.total_num_heads * self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.v_proj = ColumnParallelLinear(
self.qkv_proj = ColumnParallelLinear(
hidden_size,
self.total_num_heads * self.head_dim,
3 * self.total_num_heads * self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
Expand All @@ -109,9 +93,12 @@ def forward(
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
q, _ = self.q_proj(hidden_states)
k, _ = self.k_proj(hidden_states)
v, _ = self.v_proj(hidden_states)
qkv, _ = self.qkv_proj(hidden_states)
qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
q, k, v = torch.split(qkv, 1, dim=-2)
q = q.squeeze(dim=-2).contiguous()
k = k.squeeze(dim=-2).contiguous()
v = v.squeeze(dim=-2).contiguous()
k_cache, v_cache = kv_cache
attn_output = self.attn(
positions, q, k, v, k_cache, v_cache, input_metadata, cache_event)
Expand Down Expand Up @@ -230,32 +217,50 @@ def forward(
return next_tokens

_column_parallel_weights = ["embed_tokens.weight", "lm_head.weight",
"q_proj.weight", "k_proj.weight",
"v_proj.weight", "gate_proj.weight",
"qkv_proj.weight", "gate_proj.weight",
"up_proj.weight"]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]

def load_weights(self, weights_path: str):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, param in state_dict.items():
loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path,
name)))
for p in self._column_parallel_weights:
if p in name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
if "qkv_proj" in name or "gate_up_proj" in name:
if "qkv_proj" in name:
original_name = "qkv_proj"
weight_names = ["q_proj", "k_proj", "v_proj"]
shard_size = param.shape[0] // 3
else:
original_name = "gate_up_proj"
weight_names = ["gate_proj", "up_proj"]
shard_size = param.shape[0] // 2
weights_to_concat = []
for weight_name in weight_names:
weight = np.load(os.path.join(
weights_path, name.replace(original_name, weight_name)))
weights_to_concat.append(weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
for p in self._row_parallel_weights:
if p in name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
:shard_size * (tensor_model_parallel_rank + 1)])
loaded_weight = torch.from_numpy(
np.concatenate(weights_to_concat, axis=0))
else:
loaded_weight = torch.from_numpy(
np.load(os.path.join(weights_path, name)))
for p in self._column_parallel_weights:
if p in name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
for p in self._row_parallel_weights:
if p in name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break

assert param.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
Expand Down
74 changes: 38 additions & 36 deletions cacheflow/models/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,9 @@ def __init__(
self.head_dim = embed_dim // total_num_heads
self.scaling = self.head_dim ** -0.5

# TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
self.k_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
gather_output=False,
perform_initialization=False)
self.v_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
gather_output=False,
perform_initialization=False)
self.q_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
gather_output=False,
perform_initialization=False)
self.qkv_proj = ColumnParallelLinear(embed_dim, 3 * embed_dim, bias=bias,
gather_output=False,
perform_initialization=False)
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
input_is_parallel=True,
perform_initialization=False)
Expand All @@ -75,16 +68,18 @@ def forward(
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
q, _ = self.q_proj(hidden_states)
k, _ = self.k_proj(hidden_states)
v, _ = self.v_proj(hidden_states)
qkv, _ = self.qkv_proj(hidden_states)
qkv = qkv.reshape(qkv.shape[:-1] + (3, -1))
q, k, v = torch.split(qkv, 1, dim=-2)
q = q.squeeze(dim=-2).contiguous()
k = k.squeeze(dim=-2).contiguous()
v = v.squeeze(dim=-2).contiguous()
key_cache, value_cache = kv_cache
attn_output = self.attn(
q, k, v, key_cache, value_cache, input_metadata, cache_event)
output, _ = self.out_proj(attn_output)
return output


class OPTDecoderLayer(nn.Module):

def __init__(self, config: OPTConfig):
Expand Down Expand Up @@ -262,11 +257,7 @@ def forward(
self.lm_head_weight, hidden_states, input_metadata)
return next_tokens

_column_parallel_weights = ["embed_tokens.weight",
"q_proj.weight", "k_proj.weight",
"v_proj.weight", "fc1.weight"]
_column_parallel_biases = ["q_proj.bias", "k_proj.bias",
"v_proj.bias", "fc1.bias"]
_column_parallel_weights = ["embed_tokens.weight", "fc1.weight", "fc1.bias"]
_row_parallel_weights = ["out_proj.weight", "fc2.weight"]

def load_weights(self, weights_path: str):
Expand All @@ -275,24 +266,35 @@ def load_weights(self, weights_path: str):
for name, param in state_dict.items():
if "lm_head_weight" in name:
continue
loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path,
name)))
for p in (self._column_parallel_weights
+ self._column_parallel_biases):
if p in name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
for p in self._row_parallel_weights:
if p in name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
if "qkv_proj" in name:
shard_size = param.shape[0] // 3
weights_to_concat = []
for weight_name in ["q_proj", "k_proj", "v_proj"]:
weight = np.load(os.path.join(
weights_path, name.replace("qkv_proj", weight_name)))
weights_to_concat.append(weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
:shard_size * (tensor_model_parallel_rank + 1)])
loaded_weight = torch.from_numpy(
np.concatenate(weights_to_concat, axis=0))
else:
loaded_weight = torch.from_numpy(
np.load(os.path.join(weights_path, name)))
for p in self._column_parallel_weights:
if p in name:
shard_size = param.shape[0]
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break
for p in self._row_parallel_weights:
if p in name:
shard_size = param.shape[1]
loaded_weight = loaded_weight[
:,
shard_size * tensor_model_parallel_rank
:shard_size * (tensor_model_parallel_rank + 1)]
break

assert param.shape == loaded_weight.shape
param.data.copy_(loaded_weight)
Expand Down

0 comments on commit c8f2711

Please sign in to comment.