fix bug in batched_forward_pass #144
Conversation
Wow, thanks a lot for fixing the bug!
I am OK with this fix on the principle that it acts as a safety check, making sure all the returned tensors end up on the correct device (regardless of which device the dataloader puts them on).
Let's run the tests and see!
Wdyt @lvwerra ?
@ArvinZhuang can you run the styling and quality checks? (make style && make quality) Thanks!
attention_mask = [torch.ones_like(r) for r in responses]
Hmm, why has this been removed? 🤔
because it seems the attention_mask variable is never used
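For context, a minimal sketch of the situation (the surrounding lines are paraphrased assumptions, not the actual trl source; only the removed line is verbatim from the diff):

```python
import torch

responses = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]  # illustrative data

# The removed line: the list is built but never read again afterwards,
# so the assignment is dead code and deleting it changes no behavior.
attention_mask = [torch.ones_like(r) for r in responses]
```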
Perfect, can you run the styling checks so that the test suite gets executed?
just did
The documentation is not available anymore as the PR was closed or merged.
Thanks for fixing! Could you revert your changes in all the scripts inside examples/? After that we should be good to merge!
Thanks for iterating! 🚀
Also, I noticed this:
Had just prepared a PR to fix this 😄 thanks @ArvinZhuang
@Rebecca-Qian You are welcome :) I'm glad to contribute as well
Thanks for the fix!
Fix the bug that causes a device-mismatch issue in the batched_forward_pass method.

The reason: the tensors returned by self.data_collator are on the CPU, so the later self.model(**input_kwargs) call errors out because the model is on the GPU.

The proposed solution: call .to(self.accelerator.device) on the tensors after self.data_collator.

My transformers version: 4.26.1
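A minimal sketch of the proposed fix (the helper name and the collator call are illustrative assumptions; data_collator and accelerator.device are the names from the description above):

```python
import torch

# Hypothetical helper illustrating the fix: collate a batch, then move
# every tensor to the accelerator's device before the forward pass.
def collate_on_device(data_collator, features, device):
    batch = data_collator(features)  # the collator returns CPU tensors
    # Without this .to(...) step, self.model(**input_kwargs) raises a
    # device-mismatch error whenever the model lives on a GPU.
    return {
        key: value.to(device) if torch.is_tensor(value) else value
        for key, value in batch.items()
    }

# Inside batched_forward_pass this would be used roughly as:
#   input_kwargs = collate_on_device(self.data_collator, features,
#                                    self.accelerator.device)
#   logits = self.model(**input_kwargs)
```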