Update workflow.py #1

Open
wants to merge 1 commit into
base: main

renjie130
Owner

@renjie130 renjie130 commented Dec 27, 2024

User description

What does this PR do?

Fixes # (issue)

Before submitting


PR Type

Enhancement


Description

  • Introduces DataCollatorForLanguageModeling to simplify data collation.
  • Replaces CustomSeq2SeqTrainer with CustomTrainer, which may improve the training logic.
  • Removes SFTDataCollatorWith4DAttentionMask, simplifying the code structure.

Changes walkthrough 📝

Relevant files
Enhancement
workflow.py
Update the training workflow

src/llamafactory/train/sft/workflow.py

  • Introduces DataCollatorForLanguageModeling for data collation.
  • Replaces CustomSeq2SeqTrainer with CustomTrainer.
  • Removes the use of SFTDataCollatorWith4DAttentionMask.
  • +4/-12   

    💡 PR-Agent usage: Comment /help "your question" on any pull request to receive relevant information

    @renjie130
    Owner Author

    renjie130 commented Jan 6, 2025

    PR Reviewer Guide 🔍

    (Review updated until commit 8e9f6b8)

    Here are some key observations to aid the review process:

    ⏱️ Estimated effort to review: 2 🔵🔵⚪⚪⚪
    🏅 Score: 85
    🧪 PR contains tests
    🔒 Security concerns

    ⚡ Recommended focus areas for review

    Potential functional impact

    Replacing CustomSeq2SeqTrainer with CustomTrainer may affect the training logic; verify that all existing functionality is preserved.

        trainer = CustomTrainer(
            model=model,
            args=training_args,
            finetuning_args=finetuning_args,
            data_collator=data_collator,
            callbacks=callbacks,
            **dataset_module,
            **tokenizer_module,
        )

    Data collator change

    Replacing SFTDataCollatorWith4DAttentionMask with DataCollatorForLanguageModeling may affect the data processing logic; verify that all existing functionality is preserved.

        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
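
To make the collator concern concrete, the following is a minimal pure-Python sketch of how labels can differ between the two collation styles. It is not the transformers or LLaMA-Factory implementation. It assumes that causal-LM collation with mlm=False copies input_ids into labels and masks only padding, and that SFT-style preprocessing masks the prompt tokens with an ignore index of -100. The helper names, token ids, and prompt lengths are hypothetical.

```python
IGNORE_INDEX = -100  # assumed ignore index (the default ignore_index of PyTorch's CrossEntropyLoss)
PAD_ID = 0           # hypothetical pad token id


def pad(seq, length, value):
    """Right-pad a list to the given length."""
    return seq + [value] * (length - len(seq))


def causal_lm_labels(batch_input_ids):
    """Sketch of mlm=False style labels: copy input_ids, ignore only padding."""
    max_len = max(len(ids) for ids in batch_input_ids)
    input_ids = [pad(ids, max_len, PAD_ID) for ids in batch_input_ids]
    labels = [[IGNORE_INDEX if tok == PAD_ID else tok for tok in row] for row in input_ids]
    return input_ids, labels


def sft_labels(batch_input_ids, prompt_lengths):
    """Sketch of SFT style labels: ignore prompt tokens and padding."""
    max_len = max(len(ids) for ids in batch_input_ids)
    input_ids = [pad(ids, max_len, PAD_ID) for ids in batch_input_ids]
    labels = []
    for ids, n_prompt in zip(batch_input_ids, prompt_lengths):
        row = [IGNORE_INDEX] * n_prompt + ids[n_prompt:]
        labels.append(pad(row, max_len, IGNORE_INDEX))
    return input_ids, labels


# Two hypothetical examples: prompt tokens 11..15, response tokens 21..23.
batch = [[11, 12, 13, 21, 22], [14, 15, 23]]
prompt_lengths = [3, 2]

_, lm = causal_lm_labels(batch)
_, sft = sft_labels(batch, prompt_lengths)
print(lm)
print(sft)
```

Under these assumptions the first style computes loss on prompt tokens as well as response tokens, which is one behavior worth checking when validating the swap.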

    @renjie130
    Owner Author

    PR Description updated to latest commit (8e9f6b8)

    @renjie130
    Owner Author

    renjie130 commented Jan 6, 2025

    PR Code Suggestions ✨

    Latest suggestions up to 8e9f6b8

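    Category    Suggestion    Score
    Critical bug
    Ensure the data collator is compatible with the trainer's requirements.

    Make sure the data_collator configuration matches what CustomTrainer requires, because DataCollatorForLanguageModeling
    may not support certain training requirements.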

    src/llamafactory/train/sft/workflow.py [59]

    -data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    +data_collator = SFTDataCollatorWith4DAttentionMask(...)
    Suggestion importance[1-10]: 8

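    Why: The suggestion points out a potentially critical compatibility issue; keeping the data collator consistent with the trainer's requirements is important.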

    8

    Previous suggestions

    Suggestions up to commit 8e9f6b8
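    Category    Suggestion    Score
    Critical bug
    Ensure the data collator is consistent with the trainer's requirements.

    Make sure the data_collator configuration matches what CustomTrainer requires, because DataCollatorForLanguageModeling may not include
    all of the functionality of SFTDataCollatorWith4DAttentionMask.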

    src/llamafactory/train/sft/workflow.py [59]

    -data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    +data_collator = SFTDataCollatorWith4DAttentionMask(
    +    template=template,
    +    model=model if not training_args.predict_with_generate else None,
    +    pad_to_multiple_of=8 if training_args.do_train else None,
    +    label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
    +    block_diag_attn=model_args.block_diag_attn,
    +    attn_implementation=getattr(model.config, "_attn_implementation", None),
    +    compute_dtype=model_args.compute_dtype,
    +    **tokenizer_module,
    +)
    Suggestion importance[1-10]: 10

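    Why: This is a critical bug because DataCollatorForLanguageModeling may not include all of the functionality of SFTDataCollatorWith4DAttentionMask, which could cause training to fail or produce inaccurate results.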

    10
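
One piece of functionality at stake is the block-diagonal attention controlled by block_diag_attn. The sketch below illustrates the general idea only, under the assumption that packed sequences are marked with per-token segment indices (1, 2, ... for each packed sequence, 0 for padding) and that each token may attend only to earlier tokens of its own segment. It uses plain 0/1 lists; the function name is hypothetical and this is not the LLaMA-Factory implementation.

```python
def block_diag_causal_mask(segment_ids):
    """Build a 0/1 causal mask where attention stays inside each segment.

    segment_ids: per-token segment index, with 0 meaning padding (assumed convention).
    Returns a square list of lists: mask[q][k] == 1 if query q may attend to key k.
    """
    n = len(segment_ids)
    mask = [[0] * n for _ in range(n)]
    for q in range(n):
        for k in range(q + 1):  # causal: only keys at or before the query
            same_segment = segment_ids[q] == segment_ids[k]
            if segment_ids[q] != 0 and same_segment:
                mask[q][k] = 1
    return mask


# Two packed sequences (lengths 2 and 3) followed by one padding position.
mask = block_diag_causal_mask([1, 1, 2, 2, 2, 0])
for row in mask:
    print(row)
```

A plain 2D padding mask cannot express this separation between packed sequences, which is why dropping the 4D-mask collator deserves a check whenever packing is enabled.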
    Suggestions up to commit 8e9f6b8
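    Category    Suggestion    Score
    Critical bug
    Restore SFTDataCollatorWith4DAttentionMask so that the data collator stays compatible with the model and training arguments.

    Make sure data_collator uses SFTDataCollatorWith4DAttentionMask rather than
    DataCollatorForLanguageModeling, to remain compatible with the model and training arguments.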

    src/llamafactory/train/sft/workflow.py [59]

    -data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    +data_collator = SFTDataCollatorWith4DAttentionMask(
    +    template=template,
    +    model=model if not training_args.predict_with_generate else None,
    +    pad_to_multiple_of=8 if training_args.do_train else None,
    +    label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
    +    block_diag_attn=model_args.block_diag_attn,
    +    attn_implementation=getattr(model.config, "_attn_implementation", None),
    +    compute_dtype=model_args.compute_dtype,
    +    **tokenizer_module,
    +)
    Suggestion importance[1-10]: 10

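    Why: The suggestion addresses a critical bug by keeping the data collator compatible with the model and training arguments, avoiding potential functional problems.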

    10
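
The pad_to_multiple_of and label_pad_token_id arguments in the suggested collator can be illustrated with a small pure-Python sketch. It assumes right-side padding and an ignore index of -100; the function names, feature layout, and token ids are hypothetical and do not reproduce the library's collator.

```python
IGNORE_INDEX = -100  # assumed ignore index for the loss


def padded_length(longest, multiple):
    """Round the longest sequence length up to a multiple (or keep it if None)."""
    if multiple is None:
        return longest
    return ((longest + multiple - 1) // multiple) * multiple


def pad_batch(features, pad_token_id, label_pad_token_id, pad_to_multiple_of=None):
    """Right-pad input_ids, attention_mask, and labels to a common length."""
    longest = max(len(f["input_ids"]) for f in features)
    target = padded_length(longest, pad_to_multiple_of)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = target - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
        batch["labels"].append(f["labels"] + [label_pad_token_id] * n_pad)
    return batch


features = [
    {"input_ids": [11, 12, 13, 21, 22], "labels": [-100, -100, -100, 21, 22]},
    {"input_ids": [14, 15, 23], "labels": [-100, -100, 23]},
]
batch = pad_batch(features, pad_token_id=0, label_pad_token_id=IGNORE_INDEX, pad_to_multiple_of=8)
print(len(batch["input_ids"][0]), batch["labels"][1])
```

Padding labels with the ignore index keeps padded positions out of the loss, while padding to a multiple of 8 is commonly used to keep tensor shapes friendly to hardware kernels during training.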
    Suggestions up to commit 8e9f6b8
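    Category    Suggestion    Score
    Critical bug
    Ensure the data collator configuration is consistent with the trainer's requirements.

    Make sure the data_collator configuration matches what CustomTrainer requires, to avoid potential training or evaluation errors.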

    src/llamafactory/train/sft/workflow.py [59]

    -data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    +data_collator = SFTDataCollatorWith4DAttentionMask(...)
    Suggestion importance[1-10]: 10

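    Why: The suggestion points out a critical bug: the data_collator configuration is inconsistent with what CustomTrainer requires, which may cause training or evaluation errors.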

    10

    @renjie130
    Owner Author

    Persistent review updated to latest commit 8e9f6b8

    @renjie130
    Owner Author

    Failed to generate code suggestions for PR