💬 Add maybe_convert_to_chatml map for conversational datasets in SFT #2862

Merged · 8 commits · Feb 17, 2025
Changes from 7 commits
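In short, the new helper remaps ShareGPT-style records (`"from"`/`"value"` message keys) to ChatML (`"role"`/`"content"`) before the chat template is applied. A minimal before/after sketch, based on the docstring example added below (the import assumes the merged package layout):

```python
from trl import maybe_convert_to_chatml

# ShareGPT-style entry: "from"/"value" keys under a "conversations" column
example = {
    "conversations": [
        {"from": "user", "value": "What color is the sky?"},
        {"from": "assistant", "value": "It is blue."},
    ]
}

print(maybe_convert_to_chatml(example))
# {'messages': [{'role': 'user', 'content': 'What color is the sky?'},
#               {'role': 'assistant', 'content': 'It is blue.'}]}
```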
4 changes: 4 additions & 0 deletions docs/source/data_utils.md
@@ -12,6 +12,10 @@

[[autodoc]] maybe_apply_chat_template

## maybe_convert_to_chatml

[[autodoc]] maybe_convert_to_chatml

## extract_prompt

[[autodoc]] extract_prompt
46 changes: 46 additions & 0 deletions tests/test_data_utils.py
@@ -24,6 +24,7 @@
extract_prompt,
is_conversational,
maybe_apply_chat_template,
maybe_convert_to_chatml,
maybe_extract_prompt,
maybe_unpair_preference_dataset,
pack_examples,
@@ -435,6 +436,51 @@ def test_pack_with_dataset(self):
self.assertEqual(dataset.to_dict(), expected_output)


class TestMaybeConvertToChatML(unittest.TestCase):
def test_with_conversations_key(self):
# Particular case where the key is "conversations": we rename it to "messages"
example = {
"conversations": [
{"from": "user", "value": "What color is the sky?"},
{"from": "assistant", "value": "It is blue."},
]
}
expected_output = {
"messages": [
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."},
]
}
self.assertEqual(maybe_convert_to_chatml(example), expected_output)

def test_without_conversations_key(self):
# Same as before, but we don't rename the keys
example = {
"prompt": [{"from": "user", "value": "What color is the sky?"}],
"completion": [{"from": "assistant", "value": "It is blue."}],
}
expected_output = {
"prompt": [{"role": "user", "content": "What color is the sky?"}],
"completion": [{"role": "assistant", "content": "It is blue."}],
}
self.assertEqual(maybe_convert_to_chatml(example), expected_output)

def test_not_conversational(self):
# When not needed, the example should remain unchanged
example = {"text": "The sky is blue."}
self.assertEqual(maybe_convert_to_chatml(example), example)

def test_already_chatml(self):
# When the example is already in ChatML format, it should remain unchanged
example = {
"messages": [
{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."},
]
}
self.assertEqual(maybe_convert_to_chatml(example), example)


# Run the tests
if __name__ == "__main__":
unittest.main()
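These new tests can also be run on their own with, for example, `python -m pytest tests/test_data_utils.py::TestMaybeConvertToChatML` (assuming a development checkout of TRL with test dependencies installed).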
30 changes: 30 additions & 0 deletions tests/test_sft_trainer.py
@@ -1370,3 +1370,33 @@ def test_train_peft_model(self):
"base_layer" not in n
): # We expect the peft parameters to be different (except for the base layer)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

def test_train_with_non_chatml_conversational_data(self):
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")

# Rename role/content to from/value to ensure SFT works with non-chatML conversational data
def rename_fields(example: dict[str, list]):
return {"conversations": [{"from": m["role"], "value": m["content"]} for m in example["messages"]]}

dataset = dataset.map(rename_fields, remove_columns="messages")

with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(output_dir=tmp_dir, report_to="none")
trainer = SFTTrainer(args=training_args, model=model, train_dataset=dataset)

# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

# Train the model
trainer.train()

# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
2 changes: 2 additions & 0 deletions trl/__init__.py
@@ -26,6 +26,7 @@
"extract_prompt",
"is_conversational",
"maybe_apply_chat_template",
"maybe_convert_to_chatml",
"maybe_extract_prompt",
"maybe_unpair_preference_dataset",
"pack_examples",
@@ -126,6 +127,7 @@
extract_prompt,
is_conversational,
maybe_apply_chat_template,
maybe_convert_to_chatml,
maybe_extract_prompt,
maybe_unpair_preference_dataset,
pack_examples,
65 changes: 58 additions & 7 deletions trl/data_utils.py
@@ -31,7 +31,8 @@ def is_conversational(example: dict[str, Any]) -> bool:
dataset type.

Returns:
`bool`:
`True` if the data is in a conversational format, `False` otherwise.

Examples:

@@ -185,20 +186,21 @@ def maybe_apply_chat_template(
For keys `"messages"`, `"prompt"`, `"chosen"`, `"rejected"`, and `"completion"`, the values are lists of
messages, where each message is a dictionary with keys `"role"` and `"content"`.
tokenizer (`PreTrainedTokenizer`):
Tokenizer to apply the chat template with.
tools (`list[Union[dict, Callable]]` or `None`, *optional*, defaults to `None`):
A list of tools (callable functions) that will be accessible to the model.
If the template does not support function calling, this argument will have no effect.

Returns:
`dict[str, str]`:
Formatted example with the chat template applied.

Notes:
- This function does not alter the keys, except for language modeling datasets, where `"messages"` is replaced by `"text"`.

- In case of prompt-only data, if the last role is `"user"`, the generation prompt is added to the prompt. Else, if the last role is `"assistant"`, the final message is continued.

Example:

@@ -462,3 +464,52 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str,
# Split the values into chunks of size seq_length
examples = {k: [v[i : i + seq_length] for i in range(0, len(v), seq_length)] for k, v in examples.items()}
return examples


def maybe_convert_to_chatml(example: dict[str, list]) -> dict[str, list]:
"""
Convert a dataset entry from ShareGPT-style format (`"from"`/`"value"` message keys) to ChatML format.

This function modifies conversational data to align with OpenAI's ChatML format:
- Replaces the key `"from"` with `"role"` in message dictionaries.
- Replaces the key `"value"` with `"content"` in message dictionaries.
- Renames `"conversations"` to `"messages"` for consistency with ChatML.

Args:
example (`dict[str, list]`):
A single data entry containing a list of messages.

Returns:
`dict[str, list]`:
Example reformatted to ChatML style.

Example:
```python
>>> from trl import maybe_convert_to_chatml
>>> example = {
... "conversations": [
... {"from": "user", "value": "What color is the sky?"},
... {"from": "assistant", "value": "It is blue."}
... ]
... }
>>> maybe_convert_to_chatml(example)
{'messages': [{'role': 'user', 'content': 'What color is the sky?'},
{'role': 'assistant', 'content': 'It is blue.'}]}
```
"""
# List of possible keys containing message lists
for key in ["prompt", "completion", "chosen", "rejected", "messages", "conversations"]:
if key in example and isinstance(example[key], list):
messages = example[key]
for message in messages:
if isinstance(message, dict):
if "from" in message:
message["role"] = message.pop("from")
if "value" in message:
message["content"] = message.pop("value")

# Rename "conversations" to "messages"
if "conversations" in example:
example["messages"] = example.pop("conversations")

return example
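For context, a hedged sketch of how the helper composes with `datasets.Dataset.map` outside the trainer, mirroring the `remove_columns` handling added to `sft_trainer.py` below (the one-row toy dataset is illustrative):

```python
from datasets import Dataset
from trl import maybe_convert_to_chatml

# One-row toy dataset in ShareGPT-style layout (illustrative)
dataset = Dataset.from_dict({
    "conversations": [[
        {"from": "user", "value": "What color is the sky?"},
        {"from": "assistant", "value": "It is blue."},
    ]]
})

# map() keeps original columns, so "conversations" must be dropped
# explicitly once the function has renamed it to "messages"
dataset = dataset.map(maybe_convert_to_chatml, remove_columns="conversations")
print(dataset[0]["messages"])
# [{'role': 'user', 'content': 'What color is the sky?'},
#  {'role': 'assistant', 'content': 'It is blue.'}]
```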
11 changes: 10 additions & 1 deletion trl/trainer/sft_trainer.py
@@ -43,7 +43,7 @@
from transformers.utils import is_liger_kernel_available, is_peft_available
from transformers.utils.deprecation import deprecate_kwarg

from ..data_utils import is_conversational, maybe_apply_chat_template, maybe_convert_to_chatml, pack_examples
from .sft_config import SFTConfig
from .utils import (
ConstantLengthDataset,
@@ -395,6 +395,15 @@ def concat_prompt_completion(example):

dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"])

# Convert the dataset to ChatML if needed
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Converting {dataset_name} dataset to ChatML"
dataset = dataset.map(
maybe_convert_to_chatml,
remove_columns="conversations" if "conversations" in dataset.column_names else None,
**map_kwargs,
)

# Apply the chat template if needed
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
Expand Down