EpochIterator, TensorFlowTrainer : refactoring and fixes #20396

Merged
merged 6 commits into from
Oct 22, 2024

nicolaspi
Contributor

@nicolaspi nicolaspi commented Oct 22, 2024

Fixes #20394
Fixes #20390
Fixes #20344

TensorFlowTrainer :

  • Added usage of iterator.get_next_as_optional(), which enables steps_per_execution to exceed the previous limit of 32 while preventing tf.errors.OutOfRangeError
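The pattern behind this change can be sketched in plain Python. This is only an analogue, not TensorFlow code and not the PR's implementation: `get_next_or_empty`, `run_steps`, and the `_EMPTY` sentinel are hypothetical names standing in for an "optional" fetch that reports exhaustion instead of raising.

```python
_EMPTY = object()  # hypothetical sentinel meaning "the iterator is exhausted"


def get_next_or_empty(iterator):
    """Return the next element, or _EMPTY instead of raising StopIteration."""
    return next(iterator, _EMPTY)


def run_steps(iterator, step_fn, steps_per_execution):
    """Run up to `steps_per_execution` steps, stopping early if data runs out.

    Returns the number of steps that were actually executed.
    """
    executed = 0
    for _ in range(steps_per_execution):
        batch = get_next_or_empty(iterator)
        if batch is _EMPTY:
            # No out-of-range error: simply end this execution early.
            break
        step_fn(batch)
        executed += 1
    return executed


if __name__ == "__main__":
    seen = []
    # Only 5 batches are available although 64 steps were requested.
    print(run_steps(iter(range(5)), seen.append, steps_per_execution=64))
```

Because exhaustion is checked on every fetch, the requested number of steps per execution can be larger than the remaining data without an exception escaping the loop.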

EpochIterator :

  • Refactored to unify code base across different backends
  • Fixed edge cases involving various combinations of steps_per_epoch and steps_per_execution
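One example of such an edge case: when steps_per_epoch is not a multiple of steps_per_execution, the last execution of an epoch has to run fewer steps. Below is a minimal sketch of that arithmetic, using a hypothetical helper `execution_chunks` that is not part of Keras and is not taken from this PR.

```python
def execution_chunks(steps_per_epoch, steps_per_execution):
    """Yield (first_step, num_steps) for each execution within one epoch."""
    if steps_per_epoch < 0 or steps_per_execution < 1:
        raise ValueError("expected steps_per_epoch >= 0 and steps_per_execution >= 1")
    for first_step in range(0, steps_per_epoch, steps_per_execution):
        # The final chunk is truncated so the epoch never overruns.
        remaining = steps_per_epoch - first_step
        yield first_step, min(steps_per_execution, remaining)


if __name__ == "__main__":
    print(list(execution_chunks(10, 4)))  # [(0, 4), (4, 4), (8, 2)]
    print(list(execution_chunks(3, 8)))   # [(0, 3)]
```

The second call shows the case where steps_per_execution exceeds steps_per_epoch: a single, shortened execution covers the whole epoch.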

Collaborator

@fchollet fchollet left a comment

This is neat -- thanks for the PR.

    @property
    def num_batches(self):
        if self.steps_per_epoch:
            return self.steps_per_epoch
        # Either copied from the data_adapter, or
        # inferred at the end of an iteration.
        return self._num_batches


class EpochIterator(_EpochIterator):
Collaborator

Couldn't the two classes be merged into one?

Contributor Author

Yes, I merged them.

class EpochIterator(_EpochIterator):
    def __next__(self):
        buffer = []
        step, iterator = super().__next__()
Collaborator

Just need to call next(self._epoch_iterator)

Contributor Author

Indeed

@@ -75,55 +75,86 @@ def __init__(
    def _get_iterator(self):
        return self.data_adapter.get_numpy_iterator()

    def enumerate_epoch(self):
Collaborator

This method must be kept for backwards compatibility (for users who have a custom fit())

Contributor Author

It's back.

@fchollet
Collaborator

It's fairly confusing to have backend-specific iterators subclass _EpochIterator -- much better to have a single EpochIterator class and have backend-specific classes override what they need to change.
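The single-class design suggested here can be sketched as follows. This is an illustration under assumptions, not the Keras code: the constructor arguments, the `DoublingEpochIterator` subclass, and the way `enumerate_epoch` is kept as an alias are all hypothetical. Only the names `EpochIterator`, `_get_iterator`, and `enumerate_epoch` appear in the discussion above.

```python
import itertools


class EpochIterator:
    """Hypothetical backend-agnostic iterator over (step, batch) pairs."""

    def __init__(self, data, steps_per_epoch=None):
        self.data = data
        self.steps_per_epoch = steps_per_epoch

    def _get_iterator(self):
        # Hook that a backend-specific subclass may override.
        return iter(self.data)

    def __iter__(self):
        # islice with stop=None yields everything, so both modes share one path.
        batches = itertools.islice(self._get_iterator(), self.steps_per_epoch)
        return enumerate(batches)

    def enumerate_epoch(self):
        # Older entry point kept as a thin alias so custom fit() loops keep working.
        return iter(self)


class DoublingEpochIterator(EpochIterator):
    """Stand-in for a backend-specific class: it overrides only the hook."""

    def _get_iterator(self):
        return (2 * x for x in self.data)


if __name__ == "__main__":
    print(list(EpochIterator([1, 2, 3], steps_per_epoch=2)))
    print(list(DoublingEpochIterator([1, 2, 3]).enumerate_epoch()))
```

The shared iteration logic lives in one place, and a backend only replaces the piece that differs, which is the structure the comment argues for.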

Restored `enumerate_epoch`
@nicolaspi
Contributor Author

> It's fairly confusing to have backend-specific iterators subclass _EpochIterator -- much better to have a single EpochIterator class and have backend-specific classes override what they need to change.

I removed _EpochIterator.

Collaborator

@fchollet fchollet left a comment

LGTM, thank you for the contribution!

@google-ml-butler google-ml-butler bot added the kokoro:force-run and ready to pull labels Oct 22, 2024
@fchollet fchollet merged commit 6b662d1 into keras-team:master Oct 22, 2024
9 checks passed
@google-ml-butler google-ml-butler bot removed the awaiting review, ready to pull, and kokoro:force-run labels Oct 22, 2024