Add Support for Passing Pretokenized Datasets to TRL #166
Conversation
### Utils for custom masking / manipulating input / output strs, etc
def combine_sequence(input_element: str, output_element: str):
With my upcoming change to accept a template in the API (which will act like a verbalizer field), we'll eventually need to apply the template while combining the sequence. That will be a minor addition, and it means we can accept "input/output" plus a custom template. No need to worry about that now, though.
Cool, sounds good - that is partially why I wrote things this way. Even though the input/output keys are hardcoded in the caller, we can pass everything through here as needed.
Are the templates going to be Jinja-style, like this? https://huggingface.co/docs/transformers/en/chat_templating
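For what it's worth, a minimal sketch of what a Jinja-style template applied at combine time could look like; the template parameter and the example template string are hypothetical, not part of this PR:

```python
from jinja2 import Template


def combine_sequence(input_element: str, output_element: str, template: str | None = None) -> str:
    """Combine an input/output pair, optionally rendering a custom template."""
    if template is None:
        return input_element + output_element
    # Hypothetical "verbalizer"-style fields, per the discussion above.
    return Template(template).render(input=input_element, output=output_element)


combined = combine_sequence(
    "What is 2 + 2?",
    " 4",
    template="### Input:\n{{ input }}\n### Response:\n{{ output }}",
)
```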
Signed-off-by: Alex-Brooks <[email protected]>
Force-pushed from de61112 to 5520d92
data_kwargs = {}
if isinstance(data_collator, DataCollatorForSeq2Seq):
    # HACK: This function is never called, but is needed to sidestep TRL's internal validation.
    data_kwargs["formatting_func"] = lambda x: x
I was implementing a somewhat silly workaround for TRL's validation logic (wrapping the tokenizer to make tokenize() a no-op) when I realized that TRL already has logic for handling pretokenized data based on the columns of the passed dataset (ref). Currently, all we have to do to handle tokenized data is pass a dummy formatting function, because _prepare_dataset inspects the dataset and returns immediately without ever calling the value passed here.
There is additionally an extra argument, dataset_kwargs, a dict that can carry the bool-valued skip_prepare_dataset. I think the correct behavior is for TRL to just check that value and skip the validation in the packing=True case as well - a small change that neither affects its API nor changes the supported data formats.
I have opened an issue / pull request in TRL to this effect: huggingface/trl#1673
If that change is merged, I will open another PR to remove this hack. A rough sketch of the current workaround is below.
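For reference, a minimal sketch of the workaround under discussion, assuming a recent TRL version (the exact SFTTrainer signature varies between releases, and the model name here is just a placeholder):

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForSeq2Seq
from trl import SFTTrainer

# Placeholder model/tokenizer for illustration.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
model = AutoModelForCausalLM.from_pretrained("gpt2")

# A pretokenized dataset: the "input_ids" column is what TRL's _prepare_dataset
# inspects to return early, so the formatting function below is never invoked.
train_dataset = Dataset.from_dict(
    {"input_ids": [tokenizer("hello world").input_ids]}
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
    formatting_func=lambda x: x,  # dummy; only sidesteps TRL's validation
)
```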
Force-pushed from 1c22429 to 7fc6478
Signed-off-by: Alex-Brooks <[email protected]>
Force-pushed from 7fc6478 to 2b38589
@alex-jw-brooks Are we using huggingface/trl#1520?
    return input_element + output_element


def preprocess_and_tokenize(
I think here we'll need to accept the template as an argument and use it to combine the sequence; a rough sketch follows.
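The signature here is hypothetical and reuses the combine_sequence sketch from earlier in this thread:

```python
def preprocess_and_tokenize(element, tokenizer, template=None):
    # Hypothetical: forward the template down to combine_sequence so callers
    # control how the input/output pair is joined before tokenization.
    combined = combine_sequence(
        element["input"], element["output"], template=template
    )
    return tokenizer(combined)
```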
from tuning.config import configs


def get_data_trainer_kwargs(
I am just wondering if this wrapper function hides too much from the main code. It might make sense to move it back into the main code, especially once the hack is removed, since it is basically two steps: get the collator, then get the formatted dataset. Internally, data formatting might happen in different ways, and those can be combined into one function.
I am still thinking about it, but my initial feeling is that it would be good for the high-level steps to be visible in train(); a sketch of what I mean follows.
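A rough, illustrative sketch of the suggestion (all names below are placeholders for the discussion, not this repo's API):

```python
def get_data_collator(tokenizer, packing: bool):
    # Stub: choose e.g. DataCollatorForSeq2Seq vs. a completion-only collator.
    return None


def get_formatted_dataset(dataset, tokenizer):
    # Stub: tokenize / apply templates, or pass through if already pretokenized.
    return dataset


def train(model, dataset, tokenizer, packing: bool = False):
    # The two high-level data-prep steps stay visible at the call site:
    data_collator = get_data_collator(tokenizer, packing)
    train_dataset = get_formatted_dataset(dataset, tokenizer)
    # ... hand model / train_dataset / data_collator off to the trainer ...
```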
Closing this PR, as it was refactored and merged in #260.
Description of the change
Related issue number
How to verify the PR
Was the PR tested