support muon optimizer (#3234)

Jintao-Huang authored Feb 24, 2025
1 parent a0cf96b commit 1bd2060
Showing 9 changed files with 80 additions and 8 deletions.
31 changes: 31 additions & 0 deletions examples/train/optimizer/muon.sh
@@ -0,0 +1,31 @@
# 17GB
# ref: https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py
# `moonshotai/Moonlight-16B-A3B-Instruct` does not support training; here we use `Qwen/Qwen2.5-7B-Instruct` as an example.
CUDA_VISIBLE_DEVICES=0 \
swift sft \
    --model Qwen/Qwen2.5-7B-Instruct \
    --train_type lora \
    --dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
              'AI-ModelScope/alpaca-gpt4-data-en#500' \
              'swift/self-cognition#500' \
    --optimizer muon \
    --torch_dtype bfloat16 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --learning_rate 1e-4 \
    --lora_rank 8 \
    --lora_alpha 32 \
    --target_modules all-linear \
    --gradient_accumulation_steps 16 \
    --eval_steps 50 \
    --save_steps 50 \
    --save_total_limit 5 \
    --logging_steps 5 \
    --max_length 2048 \
    --output_dir output \
    --system 'You are a helpful assistant.' \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4 \
    --model_author swift \
    --model_name swift-robot
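The script relies on Muon's default hyperparameters. Judging from the `--optim_args` parsing added to `swift/plugin/optimizer.py` below, extra keyword arguments could presumably be forwarded to the optimizer as a comma-separated `key=value` string (for example `--optim_args 'momentum=0.95'`); the key names here are illustrative assumptions, not flags documented by this commit, and the parsed values are passed through as strings. A minimal sketch of that parsing, assuming the same logic as `create_muon_optimizers`:

```python
def parse_optim_args(optim_args: str) -> dict:
    # Same parsing as create_muon_optimizers below: strip spaces, split on
    # commas, then split each mapping on '='. Values remain strings.
    parsed = {}
    if optim_args:
        for mapping in optim_args.replace(' ', '').split(','):
            key, value = mapping.split('=')
            parsed[key] = value
    return parsed


# Hypothetical keys, shown for illustration only.
print(parse_optim_args('momentum=0.95, ns_steps=5'))
# -> {'momentum': '0.95', 'ns_steps': '5'}
```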
2 changes: 1 addition & 1 deletion swift/llm/model/model/baai.py
@@ -72,7 +72,7 @@ def get_model_tokenizer_emu3_chat(model_dir: str,
    local_repo_path = kwargs.get('local_repo_path')
    if not local_repo_path:
        local_repo_path = git_clone_github('https://github.com/baaivision/Emu3.git')
-   sys.path.append(os.path.join(local_repo_path))
+   sys.path.append(local_repo_path)
    from emu3.mllm.processing_emu3 import Emu3Processor
    processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)

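The baai.py change above, repeated in the deepseek, llama, llava, mplug, and valley files below, is behavior-preserving: `os.path.join` called with a single argument returns that argument unchanged, so dropping the wrapper only removes noise. A quick check:

```python
import os

# os.path.join with a single path and no further components is a no-op,
# so sys.path.append(os.path.join(local_repo_path)) is equivalent to
# sys.path.append(local_repo_path). The path below is illustrative.
assert os.path.join('/tmp/Emu3') == '/tmp/Emu3'
```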
6 changes: 3 additions & 3 deletions swift/llm/model/model/deepseek.py
@@ -142,7 +142,7 @@ def get_model_tokenizer_deepseek_vl(model_dir: str, *args, **kwargs):
    local_repo_path = kwargs.get('local_repo_path')
    if not local_repo_path:
        local_repo_path = git_clone_github('https://github.com/deepseek-ai/DeepSeek-VL')
-   sys.path.append(os.path.join(local_repo_path))
+   sys.path.append(local_repo_path)
    from deepseek_vl.models import VLChatProcessor
    processor = VLChatProcessor.from_pretrained(model_dir)
    return _get_deepseek_vl(processor, 'language_model', model_dir, *args, **kwargs)
@@ -169,7 +169,7 @@ def get_model_tokenizer_deepseek_janus(model_dir: str, *args, **kwargs):
    local_repo_path = kwargs.get('local_repo_path')
    if not local_repo_path:
        local_repo_path = git_clone_github('https://github.com/deepseek-ai/Janus')
-   sys.path.append(os.path.join(local_repo_path))
+   sys.path.append(local_repo_path)
    from janus.models import VLChatProcessor

    processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_dir)
@@ -210,7 +210,7 @@ def get_model_tokenizer_deepseek_vl2(model_dir: str, *args, **kwargs):
    local_repo_path = kwargs.get('local_repo_path')
    if not local_repo_path:
        local_repo_path = git_clone_github('https://github.com/deepseek-ai/DeepSeek-VL2')
-   sys.path.append(os.path.join(local_repo_path))
+   sys.path.append(local_repo_path)
    try:
        from deepseek_vl2.models import DeepseekVLV2Processor
    except ImportError:
2 changes: 1 addition & 1 deletion swift/llm/model/model/llama.py
@@ -216,7 +216,7 @@ def get_model_tokenizer_omnli(model_dir: str,
    local_repo_path = kwargs.get('local_repo_path')
    if not local_repo_path:
        local_repo_path = git_clone_github('https://github.com/ictnlp/LLaMA-Omni')
-   sys.path.append(os.path.join(local_repo_path))
+   sys.path.append(local_repo_path)
    from omni_speech.model import OmniSpeech2SLlamaForCausalLM, OmniSpeechLlamaForCausalLM
    import whisper
    model_config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
2 changes: 1 addition & 1 deletion swift/llm/model/model/llava.py
@@ -286,7 +286,7 @@ def get_model_tokenizer_llava(model_dir: str,
    else:
        repo_path = 'https://github.com/haotian-liu/LLaVA'
        local_repo_path = git_clone_github(repo_path)
-   sys.path.append(os.path.join(local_repo_path))
+   sys.path.append(local_repo_path)

    if llm_model_type == 'mistral':
        from llava.model import LlavaMistralForCausalLM, LlavaMistralConfig
2 changes: 1 addition & 1 deletion swift/llm/model/model/mplug.py
@@ -27,7 +27,7 @@ def get_model_tokenizer_mplug_owl2(model_dir: str,
    if not local_repo_path:
        local_repo_path = git_clone_github('https://github.com/X-PLUG/mPLUG-Owl')
        local_repo_path = os.path.join(local_repo_path, 'mPLUG-Owl2')
-   sys.path.append(os.path.join(local_repo_path))
+   sys.path.append(local_repo_path)

    # register
    # https://github.com/X-PLUG/mPLUG-Owl/blob/main/mPLUG-Owl2/mplug_owl2/model/modeling_mplug_owl2.py#L447
2 changes: 1 addition & 1 deletion swift/llm/model/model/valley.py
@@ -21,7 +21,7 @@ def get_model_tokenizer_valley(model_dir: str,
    if not local_repo_path:
        repo_path = 'https://github.com/bytedance/Valley.git'
        local_repo_path = git_clone_github(repo_path)
-   sys.path.append(os.path.join(local_repo_path))
+   sys.path.append(local_repo_path)

    if llm_model_type == 'valley':
        from transformers.modeling_outputs import CausalLMOutputWithPast
40 changes: 40 additions & 0 deletions swift/plugin/optimizer.py
@@ -1,5 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
import sys

from transformers import Trainer

@@ -53,8 +55,46 @@ def create_lorap_optimizers(args, model, dataset):
    return optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs), None


def create_muon_optimizers(args, model, dataset):
    from swift.llm import git_clone_github, get_model_arch
    if not args.local_repo_path:
        args.local_repo_path = git_clone_github('https://github.com/MoonshotAI/Moonlight.git')
    sys.path.append(os.path.join(args.local_repo_path, 'examples'))
    from toy_train import Muon

    # parse args.optim_args
    optim_args = {}
    if args.optim_args:
        for mapping in args.optim_args.replace(' ', '').split(','):
            key, value = mapping.split('=')
            optim_args[key] = value

    model_arch = get_model_arch(model.model_meta.model_arch)
    embed_key = model_arch.embedding or 'embed_tokens'
    lm_head_key = model_arch.lm_head or 'lm_head'
    muon_params = [
        p for n, p in model.named_parameters()
        if p.requires_grad and p.ndim >= 2 and embed_key not in n and lm_head_key not in n
    ]
    adamw_params = [
        p for n, p in model.named_parameters()
        if p.requires_grad and not (p.ndim >= 2 and embed_key not in n and lm_head_key not in n)
    ]

    return Muon(
        lr=args.learning_rate,
        wd=args.weight_decay,
        muon_params=muon_params,
        adamw_params=adamw_params,
        adamw_betas=(args.adam_beta1, args.adam_beta2),
        adamw_eps=args.adam_epsilon,
        **optim_args,
    ), None


# Add your own optimizers here, use --optimizer xxx to train
optimizers_map = {
    'galore': create_galore_optimizers,
    'lorap': create_lorap_optimizers,
    'muon': create_muon_optimizers,
}
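For reference, `optimizers_map` is the plugin hook the comment above refers to: each entry maps an `--optimizer` name to a factory that receives `(args, model, dataset)` and returns an `(optimizer, lr_scheduler)` tuple, where `None` leaves the scheduler to the trainer's default. Below is a minimal sketch of how a custom entry could be added to this same file; the `create_custom_adamw` name and the weight-decay split are illustrative assumptions, not part of this commit.

```python
from torch.optim import AdamW


def create_custom_adamw(args, model, dataset):
    # Apply weight decay only to matrix-shaped parameters, loosely mirroring
    # the muon/adamw split used by create_muon_optimizers above.
    decay_params = [p for p in model.parameters() if p.requires_grad and p.ndim >= 2]
    other_params = [p for p in model.parameters() if p.requires_grad and p.ndim < 2]
    optimizer = AdamW(
        [
            {'params': decay_params, 'weight_decay': args.weight_decay},
            {'params': other_params, 'weight_decay': 0.0},
        ],
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_epsilon,
    )
    return optimizer, None  # None -> keep the trainer's default lr_scheduler


# Register it in the map above, then train with `--optimizer custom_adamw`.
optimizers_map['custom_adamw'] = create_custom_adamw
```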
1 change: 1 addition & 0 deletions swift/trainers/arguments.py
@@ -32,6 +32,7 @@ class SwiftArgumentsMixin:
    # Value copied from TrainArguments
    train_type: Optional[str] = None
    optimizer: Optional[str] = None
    local_repo_path: Optional[str] = None
    galore_config: Optional[GaLoreConfig] = None

    def _fix_gradient_checkpointing(self):
