
Remove graph breaks for torch.compile() in padding free branch in DataCollatorForCompletionOnlyLM #2158

Merged Jan 6, 2025 · 29 commits
4472501
feat: Add info to batch in DataCollatorForCompletionOnlyLM
Abhishek-TAMU Oct 2, 2024
6cfa171
fix: formatting
Abhishek-TAMU Oct 2, 2024
a821ce0
feat: Add info to batch in DataCollatorForCompletionOnlyLM
Abhishek-TAMU Oct 2, 2024
fb669b6
fix: formatting
Abhishek-TAMU Oct 2, 2024
f4b1955
Merge branch 'huggingface:main' into collator_batch
Abhishek-TAMU Oct 14, 2024
1b7c060
Merge branch 'collator_batch' of github.com:Abhishek-TAMU/trl into co…
Abhishek-TAMU Oct 21, 2024
c3578f8
Merge branch 'main' into collator_batch
Abhishek-TAMU Oct 21, 2024
e83fc8a
fix: max_length_k to int
Abhishek-TAMU Oct 21, 2024
68554b1
fix:Added comments
Abhishek-TAMU Oct 21, 2024
2a7dd47
Merge remote-tracking branch 'trl/main' into collator_batch
Abhishek-TAMU Oct 30, 2024
b0a52e2
test cases
Abhishek-TAMU Oct 30, 2024
054a6ef
test cases
Abhishek-TAMU Oct 30, 2024
376ad21
test cases
Abhishek-TAMU Oct 30, 2024
9a08ea3
Merge remote-tracking branch 'trl/main' into collator_batch
Abhishek-TAMU Nov 12, 2024
a97045b
feat: Add info to batch in DataCollatorForCompletionOnlyLM
Abhishek-TAMU Oct 2, 2024
f31a780
fix: formatting
Abhishek-TAMU Oct 2, 2024
29ba8a3
feat: Add info to batch in DataCollatorForCompletionOnlyLM
Abhishek-TAMU Oct 2, 2024
d1441e1
test cases
Abhishek-TAMU Oct 30, 2024
d55a6e2
test cases
Abhishek-TAMU Oct 30, 2024
7dccc2d
test cases
Abhishek-TAMU Oct 30, 2024
5e5224e
unit test changes
Abhishek-TAMU Nov 12, 2024
1b434b0
unit test changes
Abhishek-TAMU Nov 12, 2024
ef1e304
Merge remote-tracking branch 'trl/main' into collator_batch
Abhishek-TAMU Nov 18, 2024
77894b1
style
qgallouedec Nov 19, 2024
911f60c
Merge branch 'main' into collator_batch
qgallouedec Nov 19, 2024
979f9f0
Merge branch 'main' into collator_batch
qgallouedec Dec 18, 2024
cebf936
Merge branch 'main' into collator_batch
qgallouedec Jan 6, 2025
ca8e153
add test
qgallouedec Jan 6, 2025
8c27e16
remove test
qgallouedec Jan 6, 2025
20 changes: 19 additions & 1 deletion tests/test_data_collator_completion_only.py
@@ -114,7 +114,7 @@ def test_padding_free(self):
inst1 = "### System: You are a helpful assistant.\n\n### User: How much is 2+2?\n\n### Assistant: 2+2 equals 4"
inst2 = "### System: You are a honest and helpful assistant.\n\n### User: What is the answer of 22x22?\n\n### Assistant: 22x22 equals 484"

- response_template = "\n### Assistant:"
+ response_template = "\n\n### Assistant:"
Review comment (Member): otherwise the template isn't found (\n\n is jointly tokenized)

collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
collator_paddingfree = DataCollatorForCompletionOnlyLM(
response_template, tokenizer=tokenizer, padding_free=True
@@ -143,3 +143,21 @@ def test_padding_free(self):
self.assertTrue((input_ids_remove_pad == batch_paddingfree["input_ids"]).all())
self.assertTrue((expected_position_ids == batch_paddingfree["position_ids"]).all())
self.assertTrue((expected_labels == batch_paddingfree["labels"]).all())

def test_data_collator_for_completion_only_lm(self):
# The tokenizer isn't used, but the collator requires one to be provided.
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")

collator = DataCollatorForCompletionOnlyLM(tokenizer.decode(9999), tokenizer=tokenizer, padding_free=True)

tokenized_instruction = [
{"input_ids": [1, 2, 3, 9999, 4, 5], "attention_mask": [1, 1, 1, 1, 1, 1]},
{"input_ids": [6, 7, 8, 9, 9999, 10, 11], "attention_mask": [1, 1, 1, 1, 1, 1, 1]},
]
batch = collator(tokenized_instruction)

self.assertEqual(batch["position_ids"].tolist(), [[0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6]]) # flat pos ids
self.assertEqual(batch["cu_seq_lens_q"].tolist(), [0, 6, 13]) # start idx of each seq + total number of tokens
self.assertEqual(batch["cu_seq_lens_k"].tolist(), [0, 6, 13]) # idem
self.assertEqual(batch["max_length_k"], 7) # max length in batch, here 7 (second sequence)
self.assertEqual(batch["max_length_q"], 7) # idem
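The template fix above (adding the second \n) matters because BPE-style tokenizers can merge adjacent newlines into a single token, so the ids produced by tokenizing "\n### Assistant:" on its own never appear as a subsequence of the full prompt's ids. A toy sketch of that failure mode, using an invented vocabulary and ids purely for illustration:

```python
# Toy "tokenizer" with an invented merge table; ids are illustrative only.
# Because "\n\n" has its own id, the ids of "\n### Assistant:" tokenized
# alone differ from the same characters inside a longer text.
def toy_encode(text):
    merges = {"\n\n": 0, "\n": 1, "### Assistant:": 2, "hello": 3}
    ids, i = [], 0
    while i < len(text):
        for piece, tid in merges.items():  # ordered longest-newline-first
            if text.startswith(piece, i):
                ids.append(tid)
                i += len(piece)
                break
        else:
            i += 1  # skip characters outside the toy vocab
    return ids

def contains(haystack, needle):
    # Subsequence search, as a collator does to locate the response template.
    return any(haystack[i:i + len(needle)] == needle
               for i in range(len(haystack) - len(needle) + 1))

full = toy_encode("hello\n\n### Assistant:")   # [3, 0, 2] — "\n\n" merged
tmpl_single = toy_encode("\n### Assistant:")    # [1, 2]
tmpl_double = toy_encode("\n\n### Assistant:")  # [0, 2]

print(contains(full, tmpl_single))  # False: single "\n" id never occurs
print(contains(full, tmpl_double))  # True: matches the merged token
```

The same mismatch happens with real subword tokenizers whenever the template's leading characters get merged with preceding text, which is why the test's template must include both newlines.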
19 changes: 19 additions & 0 deletions trl/trainer/utils.py
@@ -211,6 +211,25 @@ def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> d
batch["labels"] = batch["labels"][attn_mask.bool()].unsqueeze(0)
batch["labels"][batch["position_ids"] == 0] = self.ignore_index

# Calculate cumulative sequence lengths for queries and keys to prevent graph breaks during further computations.
flattened_position_ids = batch["position_ids"].flatten()
indices_q = torch.arange(
flattened_position_ids.size(0), device=flattened_position_ids.device, dtype=torch.int32
)
batch["cu_seq_lens_q"] = torch.cat(
(
indices_q[flattened_position_ids == 0],
torch.tensor(
flattened_position_ids.size(), device=flattened_position_ids.device, dtype=torch.int32
),
)
)
batch["cu_seq_lens_k"] = batch["cu_seq_lens_q"]

# Determine maximum sequence lengths to prevent graph breaks during further computations.
batch["max_length_k"] = flattened_position_ids.max().item() + 1
batch["max_length_q"] = batch["max_length_k"]

return batch
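The addition above derives the cumulative sequence lengths and max length directly from the packed position ids: each position-id reset to 0 marks a sequence start, and appending the total token count closes the last interval. A standalone sketch of that computation, run on the two packed sequences (lengths 6 and 7) from the test above:

```python
import torch

# Packed batch: two sequences concatenated, position ids restarting at 0
# for each sequence (as the padding-free collator produces).
position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6]])

flat = position_ids.flatten()
indices = torch.arange(flat.size(0), dtype=torch.int32)

# Sequence boundaries: every index where position_ids resets to 0,
# plus the total token count as the final boundary.
cu_seq_lens = torch.cat((indices[flat == 0],
                         torch.tensor(flat.size(), dtype=torch.int32)))

# Longest packed sequence: largest position id + 1.
max_length = flat.max().item() + 1

print(cu_seq_lens.tolist())  # [0, 6, 13]
print(max_length)            # 7
```

Precomputing these in the collator (rather than inside the model's forward pass) is what avoids the data-dependent control flow that forces torch.compile() graph breaks.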

