Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[KTO/BCO Trainer] add bos_token_id only if it exists #2279

Merged
merged 1 commit into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions trl/trainer/bco_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,17 +223,20 @@ def _process_tokens(example: Dict[str, Any], model: "PreTrainedModel" = None, **
)

# add BOS, which affects both prompt and the full completion
if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]:
batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[
f"{kwargs['prefix']}prompt_input_ids"
]
batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[f"{kwargs['prefix']}prompt_attention_mask"]
batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[
f"{kwargs['prefix']}completion_input_ids"
]
batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[
f"{kwargs['prefix']}completion_attention_mask"
]
if bos_token_id is not None:
if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]:
batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[
f"{kwargs['prefix']}prompt_input_ids"
]
batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[
f"{kwargs['prefix']}prompt_attention_mask"
]
batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[
f"{kwargs['prefix']}completion_input_ids"
]
batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[
f"{kwargs['prefix']}completion_attention_mask"
]
# add EOS, which affects only the full completion
if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]:
batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [
Expand Down
25 changes: 14 additions & 11 deletions trl/trainer/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,17 +218,20 @@ def _process_tokens(example: Dict[str, Any], model: "PreTrainedModel" = None, **
)

# add BOS, which affects both prompt and the full completion
if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]:
batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[
f"{kwargs['prefix']}prompt_input_ids"
]
batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[f"{kwargs['prefix']}prompt_attention_mask"]
batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[
f"{kwargs['prefix']}completion_input_ids"
]
batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[
f"{kwargs['prefix']}completion_attention_mask"
]
if bos_token_id is not None:
if len(all_tokens["prompt_input_ids"]) == 0 or bos_token_id != all_tokens["prompt_input_ids"][0]:
batch[f"{kwargs['prefix']}prompt_input_ids"] = [bos_token_id] + batch[
f"{kwargs['prefix']}prompt_input_ids"
]
batch[f"{kwargs['prefix']}prompt_attention_mask"] = [1] + batch[
f"{kwargs['prefix']}prompt_attention_mask"
]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just wondering if it should be:

if len(all_tokens["completion_input_ids"]) == 0 or bos_token_id != all_tokens["completion_input_ids"][0]:

before this

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I suspect there is an error here, because we shouldn't add bos to completion (expect for enc-dec). But in a near future, we will refactor as we did in #2209. Open to contributions btw.

I think we can keep it like this for the moment.

batch[f"{kwargs['prefix']}completion_input_ids"] = [bos_token_id] + batch[
f"{kwargs['prefix']}completion_input_ids"
]
batch[f"{kwargs['prefix']}completion_attention_mask"] = [1] + batch[
f"{kwargs['prefix']}completion_attention_mask"
]
# add EOS, which affects only the full completion
if len(all_tokens["answer_input_ids"]) == 0 or eos_token_id != all_tokens["answer_input_ids"][-1]:
batch[f"{kwargs['prefix']}completion_input_ids"] = batch[f"{kwargs['prefix']}completion_input_ids"] + [
Expand Down
Loading