
Add Support for Passing Pretokenized Datasets to TRL #166

Conversation

Collaborator
@alex-jw-brooks alex-jw-brooks commented May 22, 2024

Description of the change

  • Adds support for passing pretokenized datasets through to TRL (an illustrative pretokenized record is sketched below).
  • Refactors the logic for formatting the train / eval dataset and picking a collator into preprocessing utils.
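For context, a minimal sketch of what a pretokenized record looks like, assuming the usual input_ids / attention_mask / labels columns that TRL inspects to detect pretokenized data; the token values below are made up:

# Illustrative only: a pretokenized dataset carries token ids instead of raw text.
from datasets import Dataset

pretokenized = Dataset.from_list([
    {
        "input_ids": [1, 529, 3626, 292, 2],    # already-tokenized prompt + response
        "attention_mask": [1, 1, 1, 1, 1],
        "labels": [-100, -100, 3626, 292, 2],   # prompt tokens masked out with -100
    },
])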

Related issue number

How to verify the PR

  • Run the unit tests and verify that all preprocessing tests pass.
  • Run an end-to-end tuning job:
export DATA_PATH=/home/SSO/us2j7257/fms-hf-tuning/tests/data/twitter_complaints_input_output.json
export CUDA_VISIBLE_DEVICES=0
export MODEL_PATH=TinyLlama/TinyLlama-1.1B-step-50K-105b 
export OUTPUT_PATH=out


python tuning/sft_trainer.py \
  --model_name_or_path $MODEL_PATH \
  --training_data_path $DATA_PATH \
  --output_dir $OUTPUT_PATH \
  --num_train_epochs 20 \
  --per_device_train_batch_size 4 \
  --per_device_eval_batch_size 4 \
  --gradient_accumulation_steps 1 \
  --evaluation_strategy "no" \
  --save_strategy "epoch" \
  --learning_rate 0.03 \
  --weight_decay 0. \
  --warmup_ratio 0.03 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --include_tokens_per_second \
  --packing False \
  --use_flash_attn False \
  --tokenizer_name_or_path $MODEL_PATH \
  --torch_dtype "float32" \
  --peft_method "pt" \
  --num_virtual_tokens 1500 \
  --prompt_tuning_init_text "Classify if the tweet is a complaint or not:"
  • You can also inspect the output of the seq2seq collator from the run above and verify that padding is applied correctly and that the input IDs / labels / attention mask are manipulated as expected (a quick inspection sketch is included below).
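One hedged way to eyeball the collator behavior, assuming the TinyLlama model from the command above and made-up feature values:

# Minimal sketch for inspecting DataCollatorForSeq2Seq output (values are illustrative).
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-step-50K-105b")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers often ship without a pad token

collator = DataCollatorForSeq2Seq(tokenizer, padding=True, label_pad_token_id=-100)
features = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [-100, 2, 3]},
    {"input_ids": [1, 2, 3, 4, 5], "attention_mask": [1, 1, 1, 1, 1], "labels": [-100, -100, 3, 4, 5]},
]
batch = collator(features)
print(batch["input_ids"].shape)  # both rows padded to the longest sequence in the batch
print(batch["labels"])           # label padding uses -100 so it is ignored by the loss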

Was the PR tested

  • I have added >= 1 unit test(s) for every new method I have added.
  • I have ensured all unit tests pass.

@alex-jw-brooks alex-jw-brooks changed the title Formatting consolidation Add Support for Passing Pretokenized Datasets to TRL May 22, 2024
)

### Utils for custom masking / manipulating input / output strs, etc
def combine_sequence(input_element: str, output_element: str):
Collaborator

With my upcoming change to accept a template in the API (which will act like a verbalizer field), we will eventually need to apply the template while combining the sequence. That will be a minor addition; this way we can accept "input/output" plus a custom template. No need to worry about it now, though.
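For illustration, a rough sketch of what that minor addition might look like; the template parameter and the {{input}} / {{output}} placeholder syntax are hypothetical, not a settled format:

# Hypothetical sketch: combine input/output through an optional verbalizer-style template.
def combine_sequence(input_element: str, output_element: str, template: str = None) -> str:
    if template is None:
        return input_element + output_element
    # Substitute illustrative placeholders; the real template format is still to be decided.
    return template.replace("{{input}}", input_element).replace("{{output}}", output_element)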

Collaborator Author

Cool, sounds good. That is partially why I wrote it this way: even though the input/output keys are hardcoded in the caller, we can pass everything through here as needed.

Collaborator

Are the templates going to be Jinja-style, like this? https://huggingface.co/docs/transformers/en/chat_templating
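For reference, the templates in that Transformers doc are Jinja strings applied over a list of messages; whether this PR's templates follow the same style was still an open question. A toy example of the Jinja style (illustrative only):

# Illustrative Jinja-style template, in the spirit of Transformers chat templates.
jinja_style_template = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
)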

@alex-jw-brooks alex-jw-brooks marked this pull request as ready for review May 28, 2024 05:14
@alex-jw-brooks alex-jw-brooks requested a review from anhuong as a code owner May 28, 2024 05:14
@alex-jw-brooks alex-jw-brooks force-pushed the formatting_consolidation branch from de61112 to 5520d92 Compare May 28, 2024 05:22
@alex-jw-brooks alex-jw-brooks mentioned this pull request May 28, 2024
data_kwargs = {}
if isinstance(data_collator, DataCollatorForSeq2Seq):
    # HACK: This function is never called, but is needed to sidestep TRL's internal validation.
    data_kwargs["formatting_func"] = lambda x: x
Collaborator Author

@alex-jw-brooks alex-jw-brooks May 28, 2024

I was implementing a somewhat silly workaround for TRL's validation logic (wrapping the tokenizer to make tokenize() a no-op) when I realized that there is actually already logic for handling pretokenized data based on the columns of the passed dataset (ref). Currently, all we have to do to handle tokenized data is pass a dummy formatting function, because _prepare_dataset will inspect the dataset and return immediately without ever calling the value passed here.

There is additionally an extra argument, dataset_kwargs, a dict that can carry a bool-valued skip_prepare_dataset. I think the correct behavior is for TRL to check that value and skip the validation in the packing=True case; that is a small change that does not affect its API or change the supported data formats. A sketch of both options follows below.

I have opened an issue / pull request in TRL to this effect: huggingface/trl#1673

If this change is merged, I will open another PR to remove this hack.
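A minimal sketch of the two options, assuming TRL's SFTTrainer API at the time of this PR; model, tokenizer, pretokenized_dataset, and data_collator are placeholders, not objects from this repo:

from trl import SFTTrainer

# Option 1 (the hack used in this PR): pass a dummy formatting_func. TRL sees the
# pretokenized columns (e.g. "input_ids") in _prepare_dataset and returns early,
# so the lambda is never invoked; it only satisfies the validation check.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=pretokenized_dataset,
    data_collator=data_collator,
    formatting_func=lambda x: x,
)

# Option 2 (the cleaner path proposed in huggingface/trl#1673): once TRL checks
# skip_prepare_dataset before running its validation, this alone would be enough.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=pretokenized_dataset,
    data_collator=data_collator,
    dataset_kwargs={"skip_prepare_dataset": True},
)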

@alex-jw-brooks alex-jw-brooks force-pushed the formatting_consolidation branch from 1c22429 to 7fc6478 Compare May 28, 2024 05:49
@alex-jw-brooks alex-jw-brooks force-pushed the formatting_consolidation branch from 7fc6478 to 2b38589 Compare May 28, 2024 05:54
@ashokponkumar
Collaborator

@alex-jw-brooks Are we using huggingface/trl#1520?

    return input_element + output_element


def preprocess_and_tokenize(
Collaborator

I think here we'll need to accept the template as an argument and use it to combine the sequence.

from tuning.config import configs


def get_data_trainer_kwargs(
Collaborator

@Ssukriti Ssukriti Jun 14, 2024

I am just wondering if this wrapper function is hiding too much from the main code. It might make sense to move this back into the main code, especially once the hack is removed, since it is basically two steps: get the collator, then get the formatted dataset. Internally the data formatting might happen in different ways, and that part can be combined into one function (a rough sketch of that shape is below).

I am still thinking about it, but my initial feeling is that it would be good to see the high-level steps in train().
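For illustration, a hypothetical sketch of the two-step shape described above; the function names and signatures are invented for this example and are not the actual fms-hf-tuning API:

# Hypothetical: surface the two data-preparation steps directly in train().
def train(model_args, data_args, train_args, tokenizer):
    # Step 1: pick the collator (e.g. seq2seq vs. default) based on the data/config.
    data_collator = get_data_collator(train_args, tokenizer)

    # Step 2: format the datasets; internally this may tokenize raw text or pass
    # already-tokenized records through unchanged.
    train_dataset, eval_dataset = get_formatted_dataset(data_args, tokenizer)

    # ... build the trainer from the pieces above ...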

@alex-jw-brooks
Collaborator Author

Closing this PR, as it was refactored and merged in #260.
