diff --git "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md" "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md"
index 06c9e267f..5ef2ce6cc 100644
--- "a/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md"
+++ "b/docs/source/Instruction/\346\224\257\346\214\201\347\232\204\346\250\241\345\236\213\345\222\214\346\225\260\346\215\256\351\233\206.md"
@@ -657,6 +657,7 @@
|[BAAI/Emu3-Gen](https://modelscope.cn/models/BAAI/Emu3-Gen)|emu3_gen|emu3_gen|-|t2i|[BAAI/Emu3-Gen](https://huggingface.co/BAAI/Emu3-Gen)|
|[BAAI/Emu3-Chat](https://modelscope.cn/models/BAAI/Emu3-Chat)|emu3_chat|emu3_chat|transformers>=4.44.0|vision|[BAAI/Emu3-Chat](https://huggingface.co/BAAI/Emu3-Chat)|
|[stepfun-ai/GOT-OCR2_0](https://modelscope.cn/models/stepfun-ai/GOT-OCR2_0)|got_ocr2|got_ocr2|-|vision|[stepfun-ai/GOT-OCR2_0](https://huggingface.co/stepfun-ai/GOT-OCR2_0)|
+|[stepfun-ai/GOT-OCR-2.0-hf](https://modelscope.cn/models/stepfun-ai/GOT-OCR-2.0-hf)|got_ocr2_hf|got_ocr2_hf|-|vision|[stepfun-ai/GOT-OCR-2.0-hf](https://huggingface.co/stepfun-ai/GOT-OCR-2.0-hf)|
|[stepfun-ai/Step-Audio-Chat](https://modelscope.cn/models/stepfun-ai/Step-Audio-Chat)|step_audio|step_audio|funasr, sox, conformer, openai-whisper, librosa|audio|[stepfun-ai/Step-Audio-Chat](https://huggingface.co/stepfun-ai/Step-Audio-Chat)|
|[LLM-Research/Phi-3-vision-128k-instruct](https://modelscope.cn/models/LLM-Research/Phi-3-vision-128k-instruct)|phi3_vision|phi3_vision|transformers>=4.36|vision|[microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct)|
|[LLM-Research/Phi-3.5-vision-instruct](https://modelscope.cn/models/LLM-Research/Phi-3.5-vision-instruct)|phi3_vision|phi3_vision|transformers>=4.36|vision|[microsoft/Phi-3.5-vision-instruct](https://huggingface.co/microsoft/Phi-3.5-vision-instruct)|
diff --git a/docs/source_en/Instruction/Supported-models-and-datasets.md b/docs/source_en/Instruction/Supported-models-and-datasets.md
index 636a14461..63da4467b 100644
--- a/docs/source_en/Instruction/Supported-models-and-datasets.md
+++ b/docs/source_en/Instruction/Supported-models-and-datasets.md
@@ -657,6 +657,7 @@ The table below introduces the models integrated with ms-swift:
|[BAAI/Emu3-Gen](https://modelscope.cn/models/BAAI/Emu3-Gen)|emu3_gen|emu3_gen|-|t2i|[BAAI/Emu3-Gen](https://huggingface.co/BAAI/Emu3-Gen)|
|[BAAI/Emu3-Chat](https://modelscope.cn/models/BAAI/Emu3-Chat)|emu3_chat|emu3_chat|transformers>=4.44.0|vision|[BAAI/Emu3-Chat](https://huggingface.co/BAAI/Emu3-Chat)|
|[stepfun-ai/GOT-OCR2_0](https://modelscope.cn/models/stepfun-ai/GOT-OCR2_0)|got_ocr2|got_ocr2|-|vision|[stepfun-ai/GOT-OCR2_0](https://huggingface.co/stepfun-ai/GOT-OCR2_0)|
+|[stepfun-ai/GOT-OCR-2.0-hf](https://modelscope.cn/models/stepfun-ai/GOT-OCR-2.0-hf)|got_ocr2_hf|got_ocr2_hf|-|vision|[stepfun-ai/GOT-OCR-2.0-hf](https://huggingface.co/stepfun-ai/GOT-OCR-2.0-hf)|
|[stepfun-ai/Step-Audio-Chat](https://modelscope.cn/models/stepfun-ai/Step-Audio-Chat)|step_audio|step_audio|funasr, sox, conformer, openai-whisper, librosa|audio|[stepfun-ai/Step-Audio-Chat](https://huggingface.co/stepfun-ai/Step-Audio-Chat)|
|[LLM-Research/Phi-3-vision-128k-instruct](https://modelscope.cn/models/LLM-Research/Phi-3-vision-128k-instruct)|phi3_vision|phi3_vision|transformers>=4.36|vision|[microsoft/Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct)|
|[LLM-Research/Phi-3.5-vision-instruct](https://modelscope.cn/models/LLM-Research/Phi-3.5-vision-instruct)|phi3_vision|phi3_vision|transformers>=4.36|vision|[microsoft/Phi-3.5-vision-instruct](https://huggingface.co/microsoft/Phi-3.5-vision-instruct)|
diff --git a/swift/llm/model/constant.py b/swift/llm/model/constant.py
index 10d04b55d..c2d8d80b4 100644
--- a/swift/llm/model/constant.py
+++ b/swift/llm/model/constant.py
@@ -192,6 +192,7 @@ class MLLMModelType:
emu3_gen = 'emu3_gen'
emu3_chat = 'emu3_chat'
got_ocr2 = 'got_ocr2'
+ got_ocr2_hf = 'got_ocr2_hf'
step_audio = 'step_audio'
phi3_vision = 'phi3_vision'
diff --git a/swift/llm/model/model/stepfun.py b/swift/llm/model/model/stepfun.py
index 2a256b9fe..c78d7d4da 100644
--- a/swift/llm/model/model/stepfun.py
+++ b/swift/llm/model/model/stepfun.py
@@ -8,7 +8,8 @@
from swift.llm import TemplateType
from ..constant import MLLMModelType
from ..model_arch import ModelArch
-from ..register import Model, ModelGroup, ModelMeta, get_model_tokenizer_with_flash_attn, register_model
+from ..register import (Model, ModelGroup, ModelMeta, get_model_tokenizer_multimodal,
+ get_model_tokenizer_with_flash_attn, register_model)
from ..utils import git_clone_github, safe_snapshot_download
@@ -32,6 +33,27 @@ def get_model_tokenizer_got_ocr2(*args, **kwargs):
tags=['vision']))
+def get_model_tokenizer_got_ocr2_hf(model_dir, *args, **kwargs):
+ from transformers.models.got_ocr2 import GotOcr2ForConditionalGeneration
+ GotOcr2ForConditionalGeneration._no_split_modules.append('GotOcr2VisionLayer')
+ model, processor = get_model_tokenizer_multimodal(model_dir, *args, **kwargs)
+ return model, processor
+
+
+register_model(
+ ModelMeta(
+ MLLMModelType.got_ocr2_hf, [
+ ModelGroup([
+ Model('stepfun-ai/GOT-OCR-2.0-hf', 'stepfun-ai/GOT-OCR-2.0-hf'),
+ ]),
+ ],
+ TemplateType.got_ocr2_hf,
+ get_model_tokenizer_got_ocr2_hf,
+ model_arch=ModelArch.got_ocr2_hf,
+        architectures=['GotOcr2ForConditionalGeneration'],
+ tags=['vision']))
+
+
def get_model_tokenizer_step_audio(*args, **kwargs):
local_repo_path = kwargs.get('local_repo_path')
if not local_repo_path:
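
For reviewers, a minimal sketch of how the newly registered `got_ocr2_hf` model type can be exercised through ms-swift's `PtEngine`, mirroring the test added further below. The `RequestConfig` arguments here are illustrative assumptions, not part of this PR.

```python
# Hedged sketch: run inference with the model registered above.
# Assumes this PR is applied and the checkpoint can be downloaded.
from swift.llm import InferRequest, PtEngine, RequestConfig

engine = PtEngine('stepfun-ai/GOT-OCR-2.0-hf')
request = InferRequest(
    messages=[{'role': 'user', 'content': 'OCR: '}],
    images=['https://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/ocr.png'])
resp_list = engine.infer([request], RequestConfig(max_tokens=512, temperature=0))
print(resp_list[0].choices[0].message.content)
```
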
diff --git a/swift/llm/model/model_arch.py b/swift/llm/model/model_arch.py
index f86547ded..f1f853d40 100644
--- a/swift/llm/model/model_arch.py
+++ b/swift/llm/model/model_arch.py
@@ -56,6 +56,7 @@ class MLLMModelArch:
idefics3 = 'idefics3'
got_ocr2 = 'got_ocr2'
+ got_ocr2_hf = 'got_ocr2_hf'
ovis1_6 = 'ovis1_6'
molmo = 'molmo'
diff --git a/swift/llm/template/constant.py b/swift/llm/template/constant.py
index fb3c29462..07186c180 100644
--- a/swift/llm/template/constant.py
+++ b/swift/llm/template/constant.py
@@ -152,6 +152,7 @@ class MLLMTemplateType:
emu3_gen = 'emu3_gen'
got_ocr2 = 'got_ocr2'
+ got_ocr2_hf = 'got_ocr2_hf'
step_audio = 'step_audio'
idefics3 = 'idefics3'
diff --git a/swift/llm/template/template/stepfun.py b/swift/llm/template/template/stepfun.py
index e03e4f598..a06e7d4ad 100644
--- a/swift/llm/template/template/stepfun.py
+++ b/swift/llm/template/template/stepfun.py
@@ -69,6 +69,42 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
))
+class GOT_OCR2HfTemplate(Template):
+
+ def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
+ inputs: StdTemplateInputs) -> List[Context]:
+ # 'OCR: '
+ # 'OCR with format: '
+ assert media_type == 'image'
+        return ['<img>' + '<imgpad>' * 256 + '</img>\n']
+
+    def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:  # temporarily copied from the template above
+ encoded = super()._encode(inputs)
+ images = inputs.images
+ if images:
+ encoded['images'] = images
+ return encoded
+
+ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
+ res = super()._data_collator(batch, padding_to=padding_to)
+ images = self.gather_list(batch, 'images')
+ _inputs = self.processor(images, return_tensors='pt')
+        _inputs.pop('input_ids')  # the processor's input_ids do not contain the response, so they cannot be used for training
+        _inputs.pop('attention_mask')  # likewise, this attention_mask does not cover the response
+
+ res.update(_inputs.data)
+ return res
+
+
+register_template(
+ QwenTemplateMeta(
+ MLLMTemplateType.got_ocr2_hf,
+ default_system=' You should follow the instructions carefully and explain your answers in detail.',
+ template_cls=GOT_OCR2HfTemplate,
+        placeholder_tokens=['<imgpad>'],
+ ))
+
+
class StepAudioTemplate(Template):
use_model = True
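
As background for the new `_data_collator`, here is a minimal sketch of what the underlying HF processor emits for a single image: prompt-only `input_ids`/`attention_mask` plus the vision tensors, which is why the collator above drops the former and keeps only the latter. The expected keys and the pixel shape in the comments are assumptions based on the GOT-OCR2 HF integration, not something this PR asserts.

```python
# Hedged sketch: inspect the GOT-OCR-2.0-hf processor output outside ms-swift.
import io

import requests
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained('stepfun-ai/GOT-OCR-2.0-hf')
url = 'https://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/ocr.png'
image = Image.open(io.BytesIO(requests.get(url).content))

inputs = processor(image, return_tensors='pt')
print(list(inputs.keys()))           # expected: ['input_ids', 'attention_mask', 'pixel_values']
print(inputs['pixel_values'].shape)  # likely (1, 3, 1024, 1024) for the GOT-OCR2 vision tower
```
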
diff --git a/tests/test_align/test_template/test_vision.py b/tests/test_align/test_template/test_vision.py
index dc4a3de86..59f401782 100644
--- a/tests/test_align/test_template/test_vision.py
+++ b/tests/test_align/test_template/test_vision.py
@@ -157,6 +157,21 @@ def test_got_ocr():
images=['https://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/ocr.png'])
+def test_got_ocr_hf():
+ pt_engine = PtEngine('stepfun-ai/GOT-OCR-2.0-hf')
+ response = _infer_model(
+ pt_engine,
+ messages=[{
+ 'role': 'user',
+ 'content': 'OCR: '
+ }],
+ images=['https://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/ocr.png'])
+ assert response[:200] == ('简介 SWIFT支持250+LLM和35+MLLM(多模态大模型)的训练、推理、 评测和部署。开发者可以直接将'
+ '我们的框架应用到自己的Research和 生产环境中,实现模型训练评测到应用的完整链路。我们除支持了 PEFT提供的轻量训练方案外'
+ ',也提供了一个完整的Adapters库以支持 最新的训练技术,如NEFTune、LoRA+、LLaMA-PRO等,这个适配器 库可以脱离训练脚本'
+ '直接使用在自己的')
+
+
def test_llama_vision():
pt_engine = PtEngine('LLM-Research/Llama-3.2-11B-Vision-Instruct')
response = _infer_model(pt_engine)
@@ -465,6 +480,7 @@ def test_ui_tars():
# test_llava_onevision_hf()
# test_minicpmv()
# test_got_ocr()
+ test_got_ocr_hf()
# test_paligemma()
# test_paligemma2()
# test_pixtral()