Fix KerasMetricCallback prediction with generate() and inference of column names #15351

Merged: Rocketknight1 merged 6 commits into master from keras_metric_callback_fixes on Jan 27, 2022

Rocketknight1 (Member)

A number of fixes to the Keras metric callback based on testing with the notebooks. I think I've caught all the implementation differences, and in most cases we can now pass exactly the same compute_metrics function to this callback as to Trainer.

I'm still in the process of overhauling the notebooks, so I might need to make a few more tweaks to this PR if I encounter any other problems!
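
As a rough illustration of the kind of `compute_metrics` function the description refers to, here is a minimal sketch. It assumes the function receives something that unpacks into `(predictions, labels)` and returns a dict of named metrics; the accuracy metric and the stand-in arrays are hypothetical and are not taken from this PR.

```python
import numpy as np


def compute_metrics(eval_predictions):
    # Assumption: the argument unpacks into (predictions, labels),
    # with predictions given as per-class scores.
    logits, labels = eval_predictions
    predictions = np.argmax(logits, axis=-1)
    # Return a dict mapping metric names to plain Python floats.
    return {"accuracy": float(np.mean(predictions == labels))}


# Hypothetical stand-in data: three samples, two classes.
logits = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
labels = np.array([1, 0, 0])
result = compute_metrics((logits, labels))
```

Because the function only touches NumPy arrays, the same definition could in principle be reused on either side, which is the point the description makes.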

Should now have very few differences with the PyTorch implementation

HuggingFaceDocBuilder commented Jan 26, 2022

The documentation is not available anymore as the PR was closed or merged.

sgugger (Collaborator) left a comment

Thanks for fixing!

Comment on lines +135 to +146
# Welp, they're not the same length. Let's do some padding
max_len = max([batch.shape[1] for batch in batches])
num_samples = sum([batch.shape[0] for batch in batches])
output = np.full_like(
    batches[0], fill_value=padding_index, shape=[num_samples, max_len] + list(batches[0].shape[2:])
)
# i keeps track of which part of the concatenated array we're writing the next batch to
i = 0
for batch in batches:
    output[i : i + len(batch), : batch.shape[1]] = batch
    i += len(batch)
return output
Collaborator

Nice!
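
The padding logic quoted in the review can be tried in isolation. The following is a minimal standalone sketch, not the code merged in this PR: the function name `concatenate_with_padding`, the default `padding_index=-100`, and the early return for equal-length batches are assumptions made for illustration.

```python
import numpy as np


def concatenate_with_padding(batches, padding_index=-100):
    # Hypothetical helper: join batches along axis 0, padding axis 1
    # with padding_index when the batches differ in sequence length.
    if all(batch.shape[1] == batches[0].shape[1] for batch in batches):
        return np.concatenate(batches, axis=0)
    max_len = max(batch.shape[1] for batch in batches)
    num_samples = sum(batch.shape[0] for batch in batches)
    # Pre-fill the output with the padding value, keeping any trailing dims.
    output = np.full(
        [num_samples, max_len] + list(batches[0].shape[2:]),
        padding_index,
        dtype=batches[0].dtype,
    )
    # i tracks the row at which the next batch is written.
    i = 0
    for batch in batches:
        output[i : i + len(batch), : batch.shape[1]] = batch
        i += len(batch)
    return output


# Two batches with sequence lengths 3 and 5.
a = np.array([[1, 2, 3], [4, 5, 6]])
b = np.array([[7, 8, 9, 10, 11]])
padded = concatenate_with_padding([a, b])
```

Here `padded` has shape `(3, 5)`; the rows from the shorter batch are filled out with the padding value on the right.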

gante (Member) left a comment

👍

Rocketknight1 merged commit 6beae76 into master Jan 27, 2022
Rocketknight1 deleted the keras_metric_callback_fixes branch January 27, 2022 14:13
4 participants