-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
EpochIterator
, TensorFlowTrainer
: refactoring and fixes
#20396
Conversation
Tensorflow trainer's train and test functions refactoring Fix keras-team#20394 Fix keras-team#20390 Fix keras-team#20344
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is neat -- thanks for the PR.
keras/src/trainers/epoch_iterator.py
Outdated
@property | ||
def num_batches(self): | ||
if self.steps_per_epoch: | ||
return self.steps_per_epoch | ||
# Either copied from the data_adapter, or | ||
# inferred at the end of an iteration. | ||
return self._num_batches | ||
|
||
|
||
class EpochIterator(_EpochIterator): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Couldn't the two classes be merged into one?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I merged them.
keras/src/trainers/epoch_iterator.py
Outdated
class EpochIterator(_EpochIterator): | ||
def __next__(self): | ||
buffer = [] | ||
step, iterator = super().__next__() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just need to call next(self._epoch_iterator)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed
@@ -75,55 +75,86 @@ def __init__( | |||
def _get_iterator(self): | |||
return self.data_adapter.get_numpy_iterator() | |||
|
|||
def enumerate_epoch(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This method must be kept for backwards compatibility (for users who have a custom fit())
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's back.
It's fairly confusing to have backend-specific iterators subclass |
Restored `enumerate_epoch`
I removed |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you for the contribution!
Fixes #20394
Fixes #20390
Fixes #20344
TensorFlowTrainer
:iterator.get_next_as_optional()
, which enablessteps_per_execution
to exceed the previous limit of 32 while preventingtf.errors.OutOfRangeError
EpochIterator
:steps_per_epoch
andsteps_per_execution