def _replace_start_image_tags(inputs: 'StdTemplateInputs') -> None:
    """Detect a trailing start-image marker on the last user message and strip it.

    ``inputs.generate_mode`` is set to ``True`` only when the final entry of
    ``inputs.messages`` has role ``'user'`` and its string content ends with
    the marker; in that case the marker is removed from that message in place.
    Otherwise ``inputs.generate_mode`` is set to ``False`` and no message is
    modified.

    Args:
        inputs: Object exposing ``messages`` (a list of dicts read through the
            ``'role'`` and ``'content'`` keys) and a writable
            ``generate_mode`` attribute.
    """
    # compat
    # NOTE(review): in the patch text this block was recovered from, the marker
    # literal is empty (it looks like an angle-bracket tag was stripped, the
    # trailing comment "remove the ..." is cut off too). An empty marker makes
    # `endswith` always true and `[:-len('')]` == `[:-0]` == '' which wipes the
    # message. '<start-image>' is assumed from the function name -- confirm
    # against the upstream file before merging.
    start_image_tag = '<start-image>'

    generate_mode = False
    messages = inputs.messages
    if messages:  # an empty conversation has no last message to inspect
        last_message = messages[-1]
        content = last_message['content']
        # Content may be a non-string (e.g. a list of parts); only plain
        # strings can carry the trailing marker.
        if (last_message['role'] == 'user' and isinstance(content, str) and start_image_tag
                and content.endswith(start_image_tag)):
            generate_mode = True
            # Explicit end index instead of `[:-len(tag)]`: a zero-length tag
            # would turn the negative slice into `[:0]` and drop everything.
            last_message['content'] = content[:len(content) - len(start_image_tag)]
    inputs.generate_mode = generate_mode