From 8e9f6b88d03036083b2a57601ee242907724e667 Mon Sep 17 00:00:00 2001
From: renjie130 <35288954+renjie130@users.noreply.github.com>
Date: Fri, 27 Dec 2024 11:30:04 +0800
Subject: [PATCH] Update workflow.py

---
 src/llamafactory/train/sft/workflow.py | 16 ++++------------
 1 file changed, 4 insertions(+), 12 deletions(-)

diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index b290af0d91..190fe3f012 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -17,6 +17,7 @@
 
 from typing import TYPE_CHECKING, List, Optional
 
+from transformers import DataCollatorForLanguageModeling
 from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
@@ -26,6 +27,7 @@
 from ..trainer_utils import create_modelcard_and_push
 from .metric import ComputeAccuracy, ComputeSimilarity, eval_logit_processor
 from .trainer import CustomSeq2SeqTrainer
+from .trainer import CustomTrainer
 
 
 if TYPE_CHECKING:
@@ -54,16 +56,7 @@ def run_sft(
     if getattr(model, "is_quantized", False) and not training_args.do_train:
         setattr(model, "_hf_peft_config_loaded", True)  # hack here: make model compatible with prediction
 
-    data_collator = SFTDataCollatorWith4DAttentionMask(
-        template=template,
-        model=model if not training_args.predict_with_generate else None,
-        pad_to_multiple_of=8 if training_args.do_train else None,  # for shift short attention
-        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
-        block_diag_attn=model_args.block_diag_attn,
-        attn_implementation=getattr(model.config, "_attn_implementation", None),
-        compute_dtype=model_args.compute_dtype,
-        **tokenizer_module,
-    )
+    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
     # Override the decoding parameters of Seq2SeqTrainer
     training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
@@ -79,7 +72,7 @@ def run_sft(
         metric_module["preprocess_logits_for_metrics"] = eval_logit_processor
 
     # Initialize our Trainer
-    trainer = CustomSeq2SeqTrainer(
+    trainer = CustomTrainer(
         model=model,
         args=training_args,
         finetuning_args=finetuning_args,
@@ -87,7 +80,6 @@ def run_sft(
         callbacks=callbacks,
         **dataset_module,
         **tokenizer_module,
-        **metric_module,
     )
 
     # Keyword arguments for `model.generate`
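
Note (illustration, not part of the patch): the change above swaps the project's SFTDataCollatorWith4DAttentionMask for transformers' generic DataCollatorForLanguageModeling in causal-LM mode. A minimal sketch of what that collator produces, assuming a stock Hugging Face tokenizer ("gpt2" is used here purely for illustration); unlike the removed SFT collator, which padded labels with IGNORE_INDEX via label_pad_token_id, this one rebuilds labels from input_ids:

    # Illustrative sketch only -- not part of workflow.py or this patch.
    from transformers import AutoTokenizer, DataCollatorForLanguageModeling

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    features = [tokenizer("Hello world"), tokenizer("A somewhat longer example sentence")]
    batch = collator(features)

    # With mlm=False the collator pads the batch and sets labels to a copy of
    # input_ids, with padding positions replaced by -100, so the loss is
    # computed over every non-padding token in the sequence.
    print(batch["labels"])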