Hf-Mixtral-to-Metagrone-GroupGmm #2

Merged
merged 1 commit into from
Apr 27, 2024
63 changes: 60 additions & 3 deletions toolkits/model_checkpoints_convertor/mistral/hf2mcore_mixtral.py
@@ -250,6 +250,56 @@ def get_megatron_sharded_states(args, tp_size, pp_size, ep_size, pp_rank):
return tp_state_dicts


def gmm_concatenate_experts_and_clean(ep_state_dict):
    """Fuse the per-expert linear_fc1 / linear_fc2 weights of each MoE layer into a
    single tensor per layer (GroupedGEMM layout) and return them under the fused keys."""

new_state_dict = {}
# Find all layers and experts by scanning the keys
layers_and_experts = {}

for key in ep_state_dict.keys():
parts = key.split('.')

layer = int(parts[2]) # assumes format 'decoder.layers.X.mlp...'
expert = int(parts[6]) # assumes format '...local_experts.X...'
if layer not in layers_and_experts:
layers_and_experts[layer] = []
if expert not in layers_and_experts[layer]:
layers_and_experts[layer].append(expert)


# Sort experts for each layer to ensure correct order
for layer in layers_and_experts:
layers_and_experts[layer].sort()

# Process each layer
for layer, experts in layers_and_experts.items():
fc2_weights = []
fc1_weights = []

# Collect all fc2 and fc1 weights for the current layer across all experts
for expert in experts:
fc2_key = f'decoder.layers.{layer}.mlp.experts.local_experts.{expert}.linear_fc2.weight'
fc1_key = f'decoder.layers.{layer}.mlp.experts.local_experts.{expert}.linear_fc1.weight'

fc2_weights.append(ep_state_dict[fc2_key])
fc1_weights.append(ep_state_dict[fc1_key])

# Concatenate fc2 weights along the second dimension and fc1 along the first
concatenated_fc2 = torch.cat(fc2_weights, dim=1)
concatenated_fc1 = torch.cat(fc1_weights, dim=0)


# Create new keys in the state dictionary
new_fc2_key = f'decoder.layers.{layer}.mlp.linear_fc2.weight'
new_fc1_key = f'decoder.layers.{layer}.mlp.linear_fc1.weight'

new_state_dict[new_fc2_key] = concatenated_fc2
new_state_dict[new_fc1_key] = concatenated_fc1

return new_state_dict


def megatron_to_transformers_fix_query_key_value_ordering(
param, checkpoint_version, num_splits, num_heads, hidden_size
):
@@ -514,8 +564,8 @@ def convert_checkpoint_from_transformers_to_megatron(args):
layer_name = f"layers.{layer}.{out_name}.layer_norm_weight"

elif op_name.startswith("post_attention_layernorm") and weight_or_bias == "weight":
out_name = "pre_mlp_layernorm"
layer_name = f"layers.{layer}.{out_name}.{weight_or_bias}"
out_name = "mlp.linear_fc1.layer_norm_weight"
layer_name = f"layers.{layer}.{out_name}"

# handle attention K, V, Q weights
elif op_name.startswith("self_attn.query") and weight_or_bias == "weight":
@@ -575,9 +625,12 @@ def convert_checkpoint_from_transformers_to_megatron(args):
del params_dict[dense_h_to_4h_2_name]

dense_h_to_4h_name = f'decoder.layers.{layer}.mlp.experts.local_experts.{expert_id}.linear_fc1.weight'
# Why do we concatenate the two up-scaling weights (dense_h_to_4h_1_weight, dense_h_to_4h_2_weight)?
# Likely because Megatron's fused linear_fc1 stacks the gate and up projections (SwiGLU)
# in a single matrix, so both can be computed with one GEMM.
params_dict[dense_h_to_4h_name] = \
torch.cat([dense_h_to_4h_1_weight, dense_h_to_4h_2_weight], dim=0)


self_attn_query_name = f"decoder.layers.{layer}.self_attn.query.weight"
query_weight = params_dict[self_attn_query_name]
del params_dict[self_attn_query_name]
@@ -642,6 +695,7 @@ def convert_checkpoint_from_transformers_to_megatron(args):
expert_group_id = expert_group_mapping[eid]
local_expert_id = expert_local_mapping[eid]
keywords[6] = str(local_expert_id)
ep_state_dict[expert_group_id][".".join(keywords)] = output_state_dict[tp_rank]['model'][
key].clone()
output_state_dict[tp_rank]['model'].pop(key)
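For intuition about the two mapping tables used above, here is a hedged sketch of how a global expert id is typically split into an expert-parallel rank and a local expert id; the actual `expert_group_mapping` / `expert_local_mapping` in this script may be constructed differently:

```python
# Hypothetical reconstruction; assumes experts are assigned to expert-parallel
# ranks in contiguous blocks of num_experts // ep_size.
num_experts = 8   # Mixtral-8x7B routes over 8 experts
ep_size = 2       # matches --target_expert_model_parallel_size 2 in run.sh
num_local_experts = num_experts // ep_size

expert_group_mapping = {eid: eid // num_local_experts for eid in range(num_experts)}
expert_local_mapping = {eid: eid % num_local_experts for eid in range(num_experts)}

# Global expert 5 would live on EP rank 1 as local expert 1.
assert expert_group_mapping[5] == 1 and expert_local_mapping[5] == 1
```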
@@ -653,8 +707,11 @@ def convert_checkpoint_from_transformers_to_megatron(args):
os.makedirs(save_dir, exist_ok=True)
checkpoint_name = "model_optim_rng.pt"
checkpoint_path = os.path.join(save_dir, checkpoint_name)
output_state_dict[tp_rank]['model'].update(ep_state_dict[ep_rank])
updated_expert_dict = gmm_concatenate_experts_and_clean(ep_state_dict[ep_rank])
output_state_dict[tp_rank]['model'].update(updated_expert_dict)
torch.save(output_state_dict[tp_rank], checkpoint_path)




def convert_checkpoint_from_megatron_to_transformers(args):
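For reference, a minimal sketch (not part of the diff) of what the new `gmm_concatenate_experts_and_clean` helper produces on a toy expert state dict, assuming the function above is in scope. The tensor shapes are assumptions based on the usual Megatron layout, where each expert's linear_fc1 stacks the gate and up projections along dim 0 and linear_fc2 has shape [hidden, ffn]:

```python
import torch

# Toy sizes for illustration; Mixtral-8x7B itself uses hidden=4096, ffn=14336 and 8 experts.
hidden, ffn, num_experts, layer = 4, 6, 2, 0

toy_ep_state_dict = {}
for e in range(num_experts):
    prefix = f"decoder.layers.{layer}.mlp.experts.local_experts.{e}"
    # linear_fc1 holds gate and up projections stacked along dim 0, hence 2 * ffn rows.
    toy_ep_state_dict[f"{prefix}.linear_fc1.weight"] = torch.randn(2 * ffn, hidden)
    toy_ep_state_dict[f"{prefix}.linear_fc2.weight"] = torch.randn(hidden, ffn)

fused = gmm_concatenate_experts_and_clean(toy_ep_state_dict)
# fc1 tensors are stacked along dim 0, fc2 tensors along dim 1:
print(fused[f"decoder.layers.{layer}.mlp.linear_fc1.weight"].shape)  # torch.Size([24, 4])
print(fused[f"decoder.layers.{layer}.mlp.linear_fc2.weight"].shape)  # torch.Size([4, 12])
```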
6 changes: 6 additions & 0 deletions toolkits/model_checkpoints_convertor/mistral/run.sh
@@ -0,0 +1,6 @@
#!/bin/bash
python toolkits/model_checkpoints_convertor/mistral/hf2mcore_mixtral.py --load_path /workspace/checkpoints/teeny-tiny-mixtral --save_path /workspace/checkpoints/teeny-tiny-mixtral-megatrone --target_expert_model_parallel_size 2 \
--target_tensor_model_parallel_size 2 \
--target_pipeline_model_parallel_size 1 \
--target_params_dtype bf16 \
--world_size 4
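As an optional sanity check after conversion, the fused tensors can be split back into per-expert slices and compared against the original Hugging Face weights. A hypothetical helper (not part of this PR; the name is illustrative) that inverts the concatenation:

```python
import torch

def split_fused_experts(fused_fc1, fused_fc2, num_experts):
    """Inverse of gmm_concatenate_experts_and_clean for one layer: fc1 was
    concatenated along dim 0 and fc2 along dim 1, so chunking along the same
    dims (all experts share a shape) recovers the individual experts."""
    fc1_per_expert = list(torch.chunk(fused_fc1, num_experts, dim=0))
    fc2_per_expert = list(torch.chunk(fused_fc2, num_experts, dim=1))
    return fc1_per_expert, fc2_per_expert

# Round-trip check against the toy dict from the earlier sketch:
# fc1s, _ = split_fused_experts(fused["decoder.layers.0.mlp.linear_fc1.weight"],
#                               fused["decoder.layers.0.mlp.linear_fc2.weight"], num_experts=2)
# assert torch.equal(fc1s[0], toy_ep_state_dict[
#     "decoder.layers.0.mlp.experts.local_experts.0.linear_fc1.weight"])
```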