And I wonder why we add 2 special tokens when we start decoding (generation), here.
```diff
@@ -390,7 +390,7 @@ def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         predicted_tokens = [None] * predictions.shape[0]
         for i in range(predictions.shape[0]):
             predicted_tokens[i] = self._indexer.indices_to_tokens(
-                {"token_ids": predictions[0].tolist()}, self.vocab
+                {"token_ids": predictions[i].tolist()}, self.vocab
```
Wow, what a bug
Indeed, I'm surprised I didn't catch it earlier. Perhaps I was only using batch size 1 when making predictions and missed it.
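To make the failure mode concrete, here is a minimal sketch (tensor values invented) of what the old indexing did. With a batch of one, rows `0` and `i` coincide, so the bug is invisible; with a larger batch, every output row silently becomes a copy of row 0:

```python
import torch

# batch_size == 1: predictions[0] and predictions[i] are the same row,
# so the buggy indexing happens to produce correct output.
predictions = torch.tensor([[4, 8, 15]])
assert [predictions[i].tolist() for i in range(1)] == [predictions[0].tolist()]

# batch_size > 1: indexing row 0 inside the loop repeats the first
# sequence for every element of the batch.
predictions = torch.tensor([[4, 8, 15], [16, 23, 42]])
buggy = [predictions[0].tolist() for _ in range(predictions.shape[0])]
fixed = [predictions[i].tolist() for i in range(predictions.shape[0])]
print(buggy)  # [[4, 8, 15], [4, 8, 15]]  <- row 0 duplicated
print(fixed)  # [[4, 8, 15], [16, 23, 42]]
```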
@wlhgtc, I suspect it's there because huggingface does it, and to make the weights work we have to do whatever huggingface does. But it's not in the original paper, and I can't find that part in the huggingface code either. @Tobias-Rohde, do you know why?
@wlhgtc, does something bad happen when you take away the second special token? If not, I'd be in favor of removing it and retraining the model.
I will merge this though, and we can take the special token thing in a separate PR.
Yes, it's for consistency with huggingface, but that's also what they did in BART, even though it might not be mentioned explicitly. See here for the whole thread on this involving the BART authors: And here for the part where huggingface adds the token:
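For context, a quick way to see the two special tokens in the huggingface config (a sketch; the token ids below are checked against `facebook/bart-large` as of the transformers versions around the time of this thread, not every release):

```python
from transformers import BartConfig, BartTokenizer

config = BartConfig.from_pretrained("facebook/bart-large")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

# BART seeds generation with decoder_start_token_id, which is </s> (EOS),
# and <s> (BOS) is then generated as the first real step, so the decoded
# prefix contains two special tokens: [</s>, <s>].
print(config.decoder_start_token_id, tokenizer.eos_token_id)  # 2 2
print(config.bos_token_id, tokenizer.bos_token_id)            # 0 0
```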
Sorry for my late reply.
@Tobias-Rohde
@Tobias-Rohde, that code only adds one token though.
@dirkgr @wlhgtc See this part of huggingface's beam search: which calls this when using Bart: There it checks if the current generated length is 1, i.e. EOS was generated (which happens in the function I linked earlier), and if `force_bos_token_to_be_generated` is set, it forces BOS to be generated as the next token.
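A rough paraphrase of that check, for readers following along (a simplified sketch, not the actual huggingface source; the real hook is a method on the model and reads `config.force_bos_token_to_be_generated`):

```python
import torch

def adjust_logits_during_generation(
    logits: torch.Tensor,  # shape (batch_size, vocab_size)
    cur_len: int,
    bos_token_id: int,
    force_bos: bool,
) -> torch.Tensor:
    # At step 1, i.e. right after decoder_start_token_id (</s>) was emitted,
    # mask every logit except <s> so that <s> is the forced next token.
    if cur_len == 1 and force_bos:
        mask = torch.full_like(logits, float("-inf"))
        mask[:, bos_token_id] = 0
        logits = logits + mask
    return logits
```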
What needs to be done then? Should we take out the second token? What's the point of force-generating a token like that?
Originally I implemented it in this way to try to reproduce the exact behavior of what huggingface was doing while trying to figure out why I was getting worse performance on CNN/DM. I think it would be fine to remove it, especially since …
The index should be `i` rather than `0` in the `make_output_human_readable` method.