Fix KerasMetricCallback prediction with generate() and inference of column names #15351
Conversation
Should now have very few differences with the PyTorch implementation
Thanks for fixing!
# Welp, they're not the same length. Let's do some padding
max_len = max([batch.shape[1] for batch in batches])
num_samples = sum([batch.shape[0] for batch in batches])
output = np.full_like(
    batches[0], fill_value=padding_index, shape=[num_samples, max_len] + list(batches[0].shape[2:])
)
# i keeps track of which part of the concatenated array we're writing the next batch to
i = 0
for batch in batches:
    output[i : i + len(batch), : batch.shape[1]] = batch
    i += len(batch)
return output
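The padding-and-concatenation logic above can be exercised in isolation. A minimal self-contained sketch follows; the function name `pad_and_concat` and the default `padding_index=-100` are illustrative assumptions, not names from the PR:

```python
import numpy as np

def pad_and_concat(batches, padding_index=-100):
    # Hypothetical standalone version of the callback's padding step:
    # pad variable-length batches to a common sequence length, then
    # write them into one pre-filled output array.
    max_len = max(batch.shape[1] for batch in batches)
    num_samples = sum(batch.shape[0] for batch in batches)
    output = np.full_like(
        batches[0],
        fill_value=padding_index,
        shape=[num_samples, max_len] + list(batches[0].shape[2:]),
    )
    i = 0  # write offset into the concatenated array
    for batch in batches:
        output[i : i + len(batch), : batch.shape[1]] = batch
        i += len(batch)
    return output

batches = [np.ones((2, 3), dtype=np.int64), np.ones((1, 5), dtype=np.int64)]
merged = pad_and_concat(batches)
print(merged.shape)  # (3, 5)
```

Positions beyond a batch's original sequence length are left at `padding_index`, so downstream metrics can mask them out.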
Nice!
Co-authored-by: Sylvain Gugger <[email protected]>
👍
A number of fixes to the Keras metric callback, based on testing with the notebooks. I think I've caught all the implementation differences, and in most cases we can now pass exactly the same `compute_metrics` function to this callback as to `Trainer`.

I'm still in the process of overhauling the notebooks, so I might need to make a few more tweaks to this PR if I encounter any other problems!
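To illustrate what a shared metric function looks like: both interfaces hand the function a (predictions, labels) pair, so a single implementation can serve both. A minimal sketch, assuming token-level accuracy with `-100` marking padded label positions (the names and the `-100` convention here are illustrative assumptions):

```python
import numpy as np

def compute_metrics(eval_preds):
    # Assumed shared signature: a (predictions, labels) pair that can be
    # unpacked the same way whether it comes from Trainer or the callback.
    predictions, labels = eval_preds
    if predictions.ndim == labels.ndim + 1:
        # Logits were passed; reduce to predicted token ids.
        predictions = predictions.argmax(axis=-1)
    mask = labels != -100  # ignore padded label positions
    accuracy = (predictions[mask] == labels[mask]).mean()
    return {"accuracy": float(accuracy)}

preds = np.array([[1, 2, 3], [4, 5, 6]])
labels = np.array([[1, 2, -100], [4, 0, -100]])
print(compute_metrics((preds, labels)))  # {'accuracy': 0.75}
```

The masking step is what makes the function indifferent to how the padded predictions were concatenated upstream.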