diff --git a/deepspeed/module_inject/fusedqkv_utils.py b/deepspeed/module_inject/fusedqkv_utils.py
index d61e78ab8d0e..bed0607dcb43 100644
--- a/deepspeed/module_inject/fusedqkv_utils.py
+++ b/deepspeed/module_inject/fusedqkv_utils.py
@@ -4,7 +4,7 @@
 # DeepSpeed Team
 import torch
 from deepspeed.utils.logging import warning_once
-from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads
+from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list, get_num_kv_heads, get_n_embd
 import re
 
 
@@ -17,7 +17,7 @@ def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):
 
 
 def require_tp_fused_qkvw(name, mp_size):
-    fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv', 'self_attn.W_pack']
+    fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv', 'self_attn.W_pack', 'c_attn']
     if mp_size == 1:
         return False
 
@@ -38,6 +38,7 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):
         "MptBlock": 'glmtype',
         "BaichuanLayer": 'glmtype',
         "DecoderLayer": 'glmtype',
+        "GPTBigCodeBlock": 'bigcodetype'
     }
 
     def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
@@ -74,6 +75,14 @@ def _bloom_type_transpose(input, mp_size):
         split_fusedqkv = input.split(get_shard_size_list(shape[0], mp_size), dim=0)
         return split_fusedqkv[gpu_index]
 
+    def _bigcode_type_transpose(input, mp_size):
+        n_embd = get_n_embd()
+        q = input[:n_embd]
+        kv = input[n_embd:]
+        shape = q.shape
+        split_q = q.split(get_shard_size_list(shape[0], mp_size), dim=0)
+        return torch.cat((split_q[gpu_index], kv), dim=0)
+
     def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None):
 
         # suppose num_heads=n, q(n)_w means the n-th q head linear weight, the weight format are as following
@@ -87,6 +96,8 @@ def _transpose_fused_qkvw(src, mp_size, fused_qkv_type=None):
             return _codegen_type_transpose(src, mp_size)
         elif fused_qkv_type == 'glmtype':
             return _glm_type_transpose(src, mp_size)
+        elif fused_qkv_type == 'bigcodetype':
+            return _bigcode_type_transpose(src, mp_size)
 
         raise ValueError("unknown fused_qkv_type")
 
diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py
index 5b7d2209d89e..c3648ad5385b 100644
--- a/deepspeed/module_inject/replace_module.py
+++ b/deepspeed/module_inject/replace_module.py
@@ -16,7 +16,7 @@
 from .auto_tp import AutoTP, ReplaceWithTensorSlicing, Loading
 
 from deepspeed import comm as dist
-from deepspeed.module_inject.tp_shard import set_num_kv_heads
+from deepspeed.module_inject.tp_shard import set_num_kv_heads, set_n_embd
 
 from .load_checkpoint import load_model_with_checkpoint
 import time
@@ -278,6 +278,18 @@ def replace_wo_policy(module, all_reduce_linears, prefix="", state_dict=None):
         # 4. When we have num_kv_heads defined, uneven division is possible, otherwise enforce even division
         set_num_kv_heads(num_kv_heads)
 
+        # 4.1 Get n_embd
+        n_embd = None
+        multi_query_n_embd_names = ['n_embd']
+        for name in multi_query_n_embd_names:
+            if hasattr(model_config, name):
+                n_embd = getattr(model_config, name)
+            if n_embd != None:
+                break
+
+        # 4.2 set n_embd
+        set_n_embd(n_embd)
+
         # 5. Set linear policies
         _autotp.update_linear_policies()
 
diff --git a/deepspeed/module_inject/tp_shard.py b/deepspeed/module_inject/tp_shard.py
index 302b3c33c953..259927f70308 100644
--- a/deepspeed/module_inject/tp_shard.py
+++ b/deepspeed/module_inject/tp_shard.py
@@ -12,6 +12,11 @@ def set_num_kv_heads(num):
     num_kv_heads = num
 
 
+def set_n_embd(num):
+    global n_embd
+    n_embd = num
+
+
 def get_num_kv_heads():
     global num_kv_heads
     return num_kv_heads
@@ -32,6 +37,11 @@ def get_shard_size(total_size, mp_size, rank=None):
             assert False, f"Number of attention heads ({total_size}) must be divisible by mp_size ({mp_size})"
 
 
+def get_n_embd():
+    global n_embd
+    return n_embd
+
+
 def get_shard_size_list(total_size, mp_size):
     shard_sizes = []
     for i in range(mp_size):
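
For reviewers, here is a minimal standalone sketch (not part of the patch) of what the new `bigcodetype` path does to a GPTBigCode fused `c_attn` weight. GPTBigCode uses multi-query attention, so only the query block (the first `n_embd` rows) is sharded across tensor-parallel ranks, while the single shared key/value block is replicated on every rank. The shapes, the even split, and the helper name below are illustrative assumptions; the real code sizes the shards with `get_shard_size_list` and reads `n_embd` via `get_n_embd`.

```python
import torch

# Illustrative sizes only (hypothetical, not taken from the PR).
n_embd, head_dim, mp_size = 1024, 128, 4

# Fused c_attn weight layout under multi-query attention:
# [Q (n_embd rows); K (head_dim rows); V (head_dim rows)] along dim 0.
fused_qkvw = torch.randn(n_embd + 2 * head_dim, n_embd)


def shard_bigcode_qkvw(weight, n_embd, mp_size, gpu_index):
    """Hypothetical helper mirroring _bigcode_type_transpose: shard Q, replicate K/V."""
    q, kv = weight[:n_embd], weight[n_embd:]      # separate Q from the shared K/V block
    split_q = q.split(n_embd // mp_size, dim=0)   # even split for illustration
    return torch.cat((split_q[gpu_index], kv), dim=0)


for rank in range(mp_size):
    shard = shard_bigcode_qkvw(fused_qkvw, n_embd, mp_size, rank)
    # Each rank keeps n_embd / mp_size query rows plus the full K/V block.
    assert shard.shape == (n_embd // mp_size + 2 * head_dim, n_embd)
```

This is also why the patch plumbs `n_embd` through `set_n_embd`/`get_n_embd`: unlike the other fused-QKV layouts, this weight cannot simply be split row-wise across ranks, and the sharder needs `n_embd` to locate the boundary between the query block and the shared key/value block.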