Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds TQDMProgressBar in callbacks #610

Merged
merged 10 commits into from
Oct 30, 2019
Prev Previous commit
Next Next commit
fix styling
  • Loading branch information
shun-lin committed Oct 23, 2019
commit 2b48c68d008ca1158d702f4337967ecc01a8ff3c
27 changes: 13 additions & 14 deletions tensorflow_addons/callbacks/tqdm_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
import time
from collections import defaultdict

import numpy as np
import six
from tensorflow.keras.callbacks import Callback
from tensorflow_addons.utils import keras_utils
from tqdm.auto import tqdm
Expand Down Expand Up @@ -48,16 +46,17 @@ class TQDMProgressBar(Callback):
show_overall_progress (bool): False to hide overall progress bar
"""

def __init__(
self,
metrics_separator=" - ",
overall_bar_format='{l_bar}{bar} {n_fmt}/{total_fmt} ETA: {remaining}s, {rate_fmt}{postfix}',
epoch_bar_format='{n_fmt}/{total_fmt}{bar} ETA: {remaining}s - {desc}',
update_per_second=10,
leave_epoch_progress=True,
leave_overall_progress=True,
show_epoch_progress=True,
show_overall_progress=True):
def __init__(self,
metrics_separator=" - ",
overall_bar_format='{l_bar}{bar} {n_fmt}/{total_fmt} ETA: '
'{remaining}s, {rate_fmt}{postfix}',
epoch_bar_format='{n_fmt}/{total_fmt}{bar} ETA: '
'{remaining}s - {desc}',
update_per_second=10,
leave_epoch_progress=True,
leave_overall_progress=True,
show_epoch_progress=True,
show_overall_progress=True):

self.metrics_separator = metrics_separator
self.overall_bar_format = overall_bar_format
Expand Down Expand Up @@ -130,8 +129,8 @@ def on_epoch_end(self, epoch, logs={}):
self.epoch_progress_tqdm.mininterval = 0

# update the rest of the steps in epoch progress bar
self.epoch_progress_tqdm.update(
self.total_steps - self.epoch_progress_tqdm.n)
self.epoch_progress_tqdm.update(self.total_steps -
self.epoch_progress_tqdm.n)
self.epoch_progress_tqdm.close()

if self.show_overall_progress:
Expand Down