support stepfun-ai/Step-Audio-Chat #3127

Merged
merged 9 commits on Feb 16, 2025
Changes from 7 commits
1 change: 1 addition & 0 deletions docs/source/Instruction/支持的模型和数据集.md
@@ -651,6 +651,7 @@
|[BAAI/Emu3-Gen](https://modelscope.cn/models/BAAI/Emu3-Gen)|emu3_gen|emu3_gen|-|t2i|[BAAI/Emu3-Gen](https://huggingface.co/BAAI/Emu3-Gen)|
|[BAAI/Emu3-Chat](https://modelscope.cn/models/BAAI/Emu3-Chat)|emu3_chat|emu3_chat|transformers>=4.44.0|vision|[BAAI/Emu3-Chat](https://huggingface.co/BAAI/Emu3-Chat)|
|[stepfun-ai/GOT-OCR2_0](https://modelscope.cn/models/stepfun-ai/GOT-OCR2_0)|got_ocr2|got_ocr2|-|vision|[stepfun-ai/GOT-OCR2_0](https://huggingface.co/stepfun-ai/GOT-OCR2_0)|
|[stepfun-ai/Step-Audio-Chat](https://modelscope.cn/models/stepfun-ai/Step-Audio-Chat)|step_audio|step_audio|funasr, sox, conformer, openai-whisper, librosa|audio|[stepfun-ai/Step-Audio-Chat](https://huggingface.co/stepfun-ai/Step-Audio-Chat)|
|[LLM-Research/Phi-3-vision-128k-instruct](https://modelscope.cn/models/LLM-Research/Phi-3-vision-128k-instruct)|phi3_vision|phi3_vision|transformers>=4.36|vision|[microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct)|
|[LLM-Research/Phi-3.5-vision-instruct](https://modelscope.cn/models/LLM-Research/Phi-3.5-vision-instruct)|phi3_vision|phi3_vision|transformers>=4.36|vision|[microsoft/Phi-3.5-vision-instruct](https://huggingface.co/microsoft/Phi-3.5-vision-instruct)|
|[AI-ModelScope/Florence-2-base-ft](https://modelscope.cn/models/AI-ModelScope/Florence-2-base-ft)|florence|florence|-|vision|[microsoft/Florence-2-base-ft](https://huggingface.co/microsoft/Florence-2-base-ft)|
@@ -651,6 +651,7 @@ The table below introduces the models integrated with ms-swift:
|[BAAI/Emu3-Gen](https://modelscope.cn/models/BAAI/Emu3-Gen)|emu3_gen|emu3_gen|-|t2i|[BAAI/Emu3-Gen](https://huggingface.co/BAAI/Emu3-Gen)|
|[BAAI/Emu3-Chat](https://modelscope.cn/models/BAAI/Emu3-Chat)|emu3_chat|emu3_chat|transformers>=4.44.0|vision|[BAAI/Emu3-Chat](https://huggingface.co/BAAI/Emu3-Chat)|
|[stepfun-ai/GOT-OCR2_0](https://modelscope.cn/models/stepfun-ai/GOT-OCR2_0)|got_ocr2|got_ocr2|-|vision|[stepfun-ai/GOT-OCR2_0](https://huggingface.co/stepfun-ai/GOT-OCR2_0)|
|[stepfun-ai/Step-Audio-Chat](https://modelscope.cn/models/stepfun-ai/Step-Audio-Chat)|step_audio|step_audio|funasr, sox, conformer, openai-whisper, librosa|audio|[stepfun-ai/Step-Audio-Chat](https://huggingface.co/stepfun-ai/Step-Audio-Chat)|
|[LLM-Research/Phi-3-vision-128k-instruct](https://modelscope.cn/models/LLM-Research/Phi-3-vision-128k-instruct)|phi3_vision|phi3_vision|transformers>=4.36|vision|[microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct)|
|[LLM-Research/Phi-3.5-vision-instruct](https://modelscope.cn/models/LLM-Research/Phi-3.5-vision-instruct)|phi3_vision|phi3_vision|transformers>=4.36|vision|[microsoft/Phi-3.5-vision-instruct](https://huggingface.co/microsoft/Phi-3.5-vision-instruct)|
|[AI-ModelScope/Florence-2-base-ft](https://modelscope.cn/models/AI-ModelScope/Florence-2-base-ft)|florence|florence|-|vision|[microsoft/Florence-2-base-ft](https://huggingface.co/microsoft/Florence-2-base-ft)|
3 changes: 0 additions & 3 deletions swift/llm/argument/rlhf_args.py
@@ -135,9 +135,6 @@ def _set_default_ddp_config():

def _init_grpo(self):
if self.rlhf_type == 'grpo':
if self.use_lmdeploy:
# In case trl GRPOTrainer need use_vllm
self.use_vllm = True
if self.use_vllm or self.use_lmdeploy:
os.environ['USE_FAST_INFERENCE'] = '1'
self._set_default_ddp_config()
1 change: 1 addition & 0 deletions swift/llm/model/constant.py
@@ -191,6 +191,7 @@ class MLLMModelType:
emu3_gen = 'emu3_gen'
emu3_chat = 'emu3_chat'
got_ocr2 = 'got_ocr2'
step_audio = 'step_audio'

phi3_vision = 'phi3_vision'
florence = 'florence'
2 changes: 1 addition & 1 deletion swift/llm/model/model/__init__.py
@@ -1,2 +1,2 @@
from . import (baai, baichuan, bert, codefuse, deepseek, gemma, glm, internlm, llama, llava, llm, mamba, microsoft,
minicpm, minimax, mistral, mllm, mplug, openbuddy, qwen, skywork, telechat, valley, yi)
minicpm, minimax, mistral, mllm, mplug, openbuddy, qwen, skywork, stepfun, telechat, valley, yi)
21 changes: 0 additions & 21 deletions swift/llm/model/model/mllm.py
@@ -3,7 +3,6 @@
from typing import Any, Dict

import torch
from transformers import AutoModel
from transformers.dynamic_module_utils import get_class_from_dynamic_module

from swift.llm import TemplateType
@@ -18,26 +17,6 @@
logger = get_logger()


def get_model_tokenizer_got_ocr2(*args, **kwargs):
kwargs['automodel_class'] = AutoModel
model, tokenizer = get_model_tokenizer_with_flash_attn(*args, **kwargs)
return model, tokenizer


register_model(
ModelMeta(
MLLMModelType.got_ocr2, [
ModelGroup([
Model('stepfun-ai/GOT-OCR2_0', 'stepfun-ai/GOT-OCR2_0'),
]),
],
TemplateType.got_ocr2,
get_model_tokenizer_got_ocr2,
model_arch=ModelArch.got_ocr2,
architectures=['GOTQwenForCausalLM'],
tags=['vision']))


def get_model_tokenizer_idefics(model_dir: str, *args, **kwargs):
from transformers import AutoModelForVision2Seq
kwargs['automodel_class'] = kwargs['automodel_class'] or AutoModelForVision2Seq
64 changes: 64 additions & 0 deletions swift/llm/model/model/stepfun.py
@@ -0,0 +1,64 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import sys

from transformers import AutoModel

from swift.llm import TemplateType
from ..constant import MLLMModelType
from ..model_arch import ModelArch
from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
from ..utils import git_clone_github, safe_snapshot_download


def get_model_tokenizer_got_ocr2(*args, **kwargs):
kwargs['automodel_class'] = AutoModel
model, tokenizer = get_model_tokenizer_with_flash_attn(*args, **kwargs)
return model, tokenizer


register_model(
ModelMeta(
MLLMModelType.got_ocr2, [
ModelGroup([
Model('stepfun-ai/GOT-OCR2_0', 'stepfun-ai/GOT-OCR2_0'),
]),
],
TemplateType.got_ocr2,
get_model_tokenizer_got_ocr2,
model_arch=ModelArch.got_ocr2,
architectures=['GOTQwenForCausalLM'],
tags=['vision']))


def get_model_tokenizer_step_audio(*args, **kwargs):
local_repo_path = kwargs.get('local_repo_path')
if not local_repo_path:
local_repo_path = git_clone_github('https://github.com/stepfun-ai/Step-Audio.git')
sys.path.append(local_repo_path)
if not os.path.exists('speakers'):
shutil.copytree(os.path.join(local_repo_path, 'speakers'), 'speakers')
from tokenizer import StepAudioTokenizer
from tts import StepAudioTTS
encoder_path = safe_snapshot_download('stepfun-ai/Step-Audio-Tokenizer')
decoder_path = safe_snapshot_download('stepfun-ai/Step-Audio-TTS-3B')
model, tokenizer = get_model_tokenizer_with_flash_attn(*args, **kwargs)
model.encoder = StepAudioTokenizer(encoder_path)
model.decoder = StepAudioTTS(decoder_path, model.encoder)
return model, tokenizer


register_model(
ModelMeta(
MLLMModelType.step_audio, [
ModelGroup([
Model('stepfun-ai/Step-Audio-Chat', 'stepfun-ai/Step-Audio-Chat'),
]),
],
TemplateType.step_audio,
get_model_tokenizer_step_audio,
model_arch=ModelArch.step_audio,
architectures=['Step1ForCausalLM'],
requires=['funasr', 'sox', 'conformer', 'openai-whisper', 'librosa'],
tags=['audio']))
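
For reference, a minimal inference sketch for the newly registered model (illustrative only; it mirrors the call pattern of tests/test_align/test_template/test_audio.py further down, and assumes the extra audio dependencies listed in the docs table are installed):

from swift.llm import PtEngine, RequestConfig

# Load Step-Audio-Chat through the registered step_audio model type; the first call
# also clones stepfun-ai/Step-Audio and downloads the tokenizer/TTS checkpoints.
engine = PtEngine('stepfun-ai/Step-Audio-Chat')
resp = engine.infer([{
    'messages': [{'role': 'user', 'content': '<audio>'}],
    # placeholder audio sample, same one the test below uses
    'audios': ['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/weather.wav'],
}], request_config=RequestConfig(max_tokens=128, temperature=0))
print(resp[0].choices[0].message.content)
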
4 changes: 4 additions & 0 deletions swift/llm/model/model_arch.py
@@ -56,6 +56,8 @@ class MLLMModelArch:
idefics3 = 'idefics3'

got_ocr2 = 'got_ocr2'
step_audio = 'step_audio'

ovis1_6 = 'ovis1_6'
molmo = 'molmo'
emu3_chat = 'emu3_chat'
@@ -479,6 +481,8 @@ def register_model_arch(model_arch: ModelKeys, *, exist_ok: bool = False) -> Non
vision_tower='model.vision_tower_high',
))

register_model_arch(MultiModelKeys(MLLMModelArch.step_audio, language_model='model', generator='decoder'))

register_model_arch(
MultiModelKeys(
MLLMModelArch.llama3_2_vision,
1 change: 1 addition & 0 deletions swift/llm/template/base.py
@@ -254,6 +254,7 @@ def _seq_cls_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
encoded['labels'] = int(inputs.label)
return encoded

@torch.inference_mode()
def encode(self,
inputs: Union[TemplateInputs, Dict[str, Any], InferRequest],
return_template_inputs: bool = False) -> Dict[str, Any]:
2 changes: 2 additions & 0 deletions swift/llm/template/constant.py
@@ -151,6 +151,8 @@ class MLLMTemplateType:
emu3_gen = 'emu3_gen'

got_ocr2 = 'got_ocr2'
step_audio = 'step_audio'

idefics3 = 'idefics3'
pixtral = 'pixtral'
paligemma = 'paligemma'
4 changes: 2 additions & 2 deletions swift/llm/template/template/__init__.py
@@ -1,2 +1,2 @@
from . import (deepseek, emu3, gemma, glm, got_ocr, idefics3, internlm, internvl, llama, llava, llm, megrez, microsoft,
minicpm, minimax, molmo, mplug, openbuddy, pixtral, qwen, valley, yi)
from . import (deepseek, emu3, gemma, glm, idefics3, internlm, internvl, llama, llava, llm, megrez, microsoft, minicpm,
minimax, molmo, mplug, openbuddy, pixtral, qwen, stepfun, valley, yi)
@@ -3,9 +3,10 @@

from ..base import Template
from ..constant import MLLMTemplateType
from ..register import register_template
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context
from ..vision_utils import load_file
from .qwen import QwenTemplateMeta


@@ -66,3 +67,26 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
template_cls=GOT_OCR2Template,
placeholder_tokens=['<imgpad>'],
))


class StepAudioTemplate(Template):
use_model = True

def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
inputs: StdTemplateInputs) -> List[Context]:
from utils import load_audio
audio_wav, sr = load_audio(load_file(inputs.audios[index]))
audio_tokens = self.model.encoder(audio_wav, sr)
return audio_tokens


register_template(
TemplateMeta(
MLLMTemplateType.step_audio,
template_cls=StepAudioTemplate,
prefix=['<s>'],
prompt=['<|BOT|>human\n{{QUERY}}<|EOT|><|BOT|>assistant\n'],
system_prefix=['<s><|BOT|>system\n{{SYSTEM}}<|EOT|>'],
chat_sep=['<|EOT|>'],
suffix=['<|EOT|>'],
))
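
For illustration, given this TemplateMeta a single-turn conversation with a system message renders roughly as below (a sketch derived from the prefix/system_prefix/prompt fields above; the system and query strings are placeholders, and <audio> is expanded into audio tokens by StepAudioTemplate.replace_tag):

# Sketch of how the step_audio template assembles one turn (placeholders, not PR code).
system = 'You are a helpful assistant.'      # placeholder system message
query = '<audio>这段语音说了什么'             # placeholder query; <audio> becomes audio tokens

prompt = (f'<s><|BOT|>system\n{system}<|EOT|>'   # system_prefix
          f'<|BOT|>human\n{query}<|EOT|>'        # prompt, user turn
          f'<|BOT|>assistant\n')                 # the model continues from here
print(prompt)

Later rounds are joined with the chat_sep token '<|EOT|>', and each assistant reply is terminated with the suffix '<|EOT|>'.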
3 changes: 3 additions & 0 deletions swift/trainers/rlhf_arguments.py
@@ -49,6 +49,9 @@ class GRPOConfig(GRPOArgumentsMixin, SwiftArgumentsMixin, HfGRPOConfig):
repetition_penalty: Optional[float] = None

def __post_init__(self):
if self.use_lmdeploy:
# in case the trl GRPOTrainer needs use_vllm
self.use_vllm = True
super().__post_init__()
if self.cosine_max_len is None:
self.cosine_max_len = self.max_completion_length
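
The coercion added above behaves like this self-contained sketch (a simplified stand-in for the relevant GRPOConfig fields, not the real class, which also inherits trl's GRPOConfig):

from dataclasses import dataclass

@dataclass
class _FastInferenceFlags:
    # simplified stand-in for GRPOConfig's fast-inference switches
    use_vllm: bool = False
    use_lmdeploy: bool = False

    def __post_init__(self):
        if self.use_lmdeploy:
            # trl's GRPOTrainer keys off use_vllm, so lmdeploy implies it
            self.use_vllm = True

assert _FastInferenceFlags(use_lmdeploy=True).use_vllm
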
8 changes: 6 additions & 2 deletions swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -192,11 +192,15 @@ def __init__(self,
}
})
self.engine.default_template = self.template
self._last_loaded_step = 0  # track the last synced step to avoid redundant weight reloads during gradient accumulation

# When using vLLM, the main process is responsible for loading the model weights. This can cause process
# desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
# synchronize all processes after vLLM has been fully initialized.
self.accelerator.wait_for_everyone()
else:
from swift.llm import PtEngine
self.engine = PtEngine.from_model_template(self.model, self.template, max_batch_size=0) # 0: no limit
self._last_loaded_step = 0
self.accelerator.wait_for_everyone()
self.request_config = RequestConfig(
max_tokens=args.max_completion_length,
temperature=args.temperature,
38 changes: 23 additions & 15 deletions tests/test_align/test_template/test_audio.py
@@ -1,23 +1,24 @@
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'


def _infer_model(pt_engine, system=None):
def _infer_model(pt_engine, system=None, messages=None, audios=None):
seed_everything(42)
request_config = RequestConfig(max_tokens=128, temperature=0)
messages = []
if system is not None:
messages += [{'role': 'system', 'content': system}]
messages += [{'role': 'user', 'content': '你好'}]
resp = pt_engine.infer([{'messages': messages}], request_config=request_config)
response = resp[0].choices[0].message.content
messages += [{'role': 'assistant', 'content': response}, {'role': 'user', 'content': '<audio>这段语音说了什么'}]
resp = pt_engine.infer([{
'messages': messages,
'audios': ['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/weather.wav']
}],
request_config=request_config)
if messages is None:
messages = []
if system is not None:
messages += [{'role': 'system', 'content': system}]
messages += [{'role': 'user', 'content': '你好'}]
resp = pt_engine.infer([{'messages': messages}], request_config=request_config)
response = resp[0].choices[0].message.content
messages += [{'role': 'assistant', 'content': response}, {'role': 'user', 'content': '<audio>这段语音说了什么'}]
else:
messages = messages.copy()
if audios is None:
audios = ['http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/weather.wav']
resp = pt_engine.infer([{'messages': messages, 'audios': audios}], request_config=request_config)
response = resp[0].choices[0].message.content
messages += [{'role': 'assistant', 'content': response}]
logger.info(f'model: {pt_engine.model_info.model_name}, messages: {messages}')
@@ -43,10 +44,17 @@ def test_xcomposer2d5_ol():
_infer_model(pt_engine)


def test_step_audio_chat():
pt_engine = PtEngine('stepfun-ai/Step-Audio-Chat')
response = _infer_model(pt_engine, messages=[{'role': 'user', 'content': '<audio>'}])
assert response == ('是的呢,今天天气晴朗,阳光明媚,微风和煦,非常适合外出活动。天空湛蓝,白云朵朵,让人心情愉悦。希望你能好好享受这美好的一天!')


if __name__ == '__main__':
from swift.llm import PtEngine, RequestConfig, get_template
from swift.utils import get_logger, seed_everything
logger = get_logger()
# test_qwen_audio()
# test_qwen2_audio()
test_xcomposer2d5_ol()
# test_xcomposer2d5_ol()
test_step_audio_chat()