Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix label datatype in TF Trainer #9616

Merged
merged 2 commits into from
Jan 20, 2021
Merged

Fix label datatype in TF Trainer #9616

merged 2 commits into from
Jan 20, 2021

Conversation

jplu
Copy link
Contributor

@jplu jplu commented Jan 15, 2021

What does this PR do?

This PR fixes the case where labels can be either a dict or a tf.Tensor when doing gradient accumulation.

@jplu jplu requested review from sgugger and LysandreJik January 15, 2021 10:03
@jplu jplu mentioned this pull request Jan 15, 2021
5 tasks
Copy link
Collaborator

@sgugger sgugger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks okay to me, but it looks increasingly clearer that we should have tests of the TFTrainer otherwise we are doing more harm than good by merging those kinds of PRs.

Copy link
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, LGTM!

@LysandreJik
Copy link
Member

I agree with Sylvain that while this is not tested, it's hard to recommend using it.

@jplu jplu merged commit 12f0d7e into huggingface:master Jan 20, 2021
@jplu jplu deleted the fix-trainer branch January 20, 2021 11:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants