diff --git a/docs/tutorials/tqdm_progress_bar.ipynb b/docs/tutorials/tqdm_progress_bar.ipynb
new file mode 100644
index 0000000000..88f26a069c
--- /dev/null
+++ b/docs/tutorials/tqdm_progress_bar.ipynb
@@ -0,0 +1,222 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "##### Copyright 2019 The TensorFlow Authors."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# TensorFlow Addons Callbacks: TQDM Progress Bar"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Overview\n",
+ "This notebook will demonstrate how to use TQDMCallback in TensorFlow Addons."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install -q tqdm>=4.36.1\n",
+ "\n",
+ "!pip install -q ipywidgets\n",
+ "!pip install -q --no-deps tensorflow-addons~=0.6\n",
+ "!jupyter nbextension enable --py widgetsnbextension --sys-prefix"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "try:\n",
+ " # %tensorflow_version only exists in Colab.\n",
+ " %tensorflow_version 2.x\n",
+ "except Exception:\n",
+ " pass"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import tensorflow as tf\n",
+ "import tensorflow_addons as tfa\n",
+ "\n",
+ "import tensorflow.keras as keras\n",
+ "from tensorflow.keras.datasets import mnist\n",
+ "from tensorflow.keras.models import Sequential\n",
+ "from tensorflow.keras.layers import Dense, Dropout, Flatten"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Import and Normalize Data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# the data, split between train and test sets\n",
+ "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
+ "# normalize data\n",
+ "x_train, x_test = x_train / 255.0, x_test / 255.0"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Build Simple MNIST CNN Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# build the model using the Sequential API\n",
+ "model = Sequential()\n",
+ "model.add(Flatten(input_shape=(28, 28)))\n",
+ "model.add(Dense(128, activation='relu'))\n",
+ "model.add(Dropout(0.2))\n",
+ "model.add(Dense(10, activation='softmax'))\n",
+ "\n",
+ "model.compile(optimizer='adam',\n",
+ " loss = 'sparse_categorical_crossentropy',\n",
+ " metrics=['accuracy'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Default TQDMCallback Usage"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "# initialize tqdm callback with default parameters\n",
+ "tqdm_callback = tfa.callbacks.TQDMProgressBar()\n",
+ "\n",
+ "# train the model with tqdm_callback\n",
+ "# make sure to set verbose = 0 to disable\n",
+ "# the default progress bar.\n",
+ "model.fit(x_train, y_train,\n",
+ " batch_size=64,\n",
+ " epochs=10,\n",
+ " verbose=0,\n",
+ " callbacks=[tqdm_callback],\n",
+ " validation_data=(x_test, y_test))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Below is the expected output when you run the cell above**\n",
+ "![TQDM Progress Bar Figure](https://raw.githubusercontent.com/tensorflow/addons/59961669a0e21eb4c045d4ad38d008a529d566c2/docs/tutorials/assets/tqdm_progress_bar_demo.png)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/tensorflow_addons/callbacks/BUILD b/tensorflow_addons/callbacks/BUILD
index e0388beaa0..c6bd1c573b 100644
--- a/tensorflow_addons/callbacks/BUILD
+++ b/tensorflow_addons/callbacks/BUILD
@@ -6,6 +6,7 @@ py_library(
name = "callbacks",
srcs = [
"__init__.py",
+ "tqdm_progress_bar.py",
],
srcs_version = "PY2AND3",
deps = [
diff --git a/tensorflow_addons/callbacks/README.md b/tensorflow_addons/callbacks/README.md
index b16710c96e..103e334b44 100644
--- a/tensorflow_addons/callbacks/README.md
+++ b/tensorflow_addons/callbacks/README.md
@@ -3,12 +3,12 @@
## Maintainers
| Submodule | Maintainers | Contact Info |
|:---------- |:------------- |:--------------|
-| | | |
+| tqdm_progress_bar | @shun-lin | shunlin@google.com |
## Contents
-| Submodule | Metric | Reference |
+| Submodule | Callback | Reference |
|:----------------------- |:-------------------|:---------------|
-| | | |
+| tqdm_progress_bar | TQDMProgressBar | https://tqdm.github.io/ |
## Contribution Guidelines
@@ -21,7 +21,7 @@ must:
* Add the addon to the `py_library` in this sub-package's BUILD file.
#### Testing Requirements
- * Simple unittests that demonstrate the metric is behaving as expected.
+ * Simple unittests that demonstrate the callback is behaving as expected.
* When applicable, run all unittests with TensorFlow's
`@run_in_graph_and_eager_modes` (for test method)
or `run_all_in_graph_and_eager_modes` (for TestCase subclass)
diff --git a/tensorflow_addons/callbacks/__init__.py b/tensorflow_addons/callbacks/__init__.py
index 3d79cb58bf..96c7cba775 100755
--- a/tensorflow_addons/callbacks/__init__.py
+++ b/tensorflow_addons/callbacks/__init__.py
@@ -17,3 +17,5 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+
+from tensorflow_addons.callbacks.tqdm_progress_bar import TQDMProgressBar
\ No newline at end of file
diff --git a/tensorflow_addons/callbacks/tqdm_progress_bar.py b/tensorflow_addons/callbacks/tqdm_progress_bar.py
new file mode 100644
index 0000000000..be9d49e451
--- /dev/null
+++ b/tensorflow_addons/callbacks/tqdm_progress_bar.py
@@ -0,0 +1,202 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TQDM Progress Bar."""
+
+from __future__ import absolute_import, division, print_function
+
+import time
+from collections import defaultdict
+
+from tensorflow.keras.callbacks import Callback
+from tensorflow_addons.utils import keras_utils
+
+
+@keras_utils.register_keras_custom_object
+class TQDMProgressBar(Callback):
+ """TQDM Progress Bar for Tensorflow Keras.
+
+ Arguments:
+ metrics_separator (string): Custom separator between metrics.
+ Defaults to ' - '
+ overall_bar_format (string format): Custom bar format for overall
+ (outer) progress bar, see https://github.com/tqdm/tqdm#parameters
+ for more detail.
+ epoch_bar_format (string format): Custom bar format for epoch
+ (inner) progress bar, see https://github.com/tqdm/tqdm#parameters
+ for more detail.
+ update_per_second (int): Maximum number of updates in the epochs bar
+ per second, this is to prevent small batches from slowing down
+ training. Defaults to 10.
+ leave_epoch_progress (bool): True to leave epoch progress bars
+ leave_overall_progress (bool): True to leave overall progress bar
+ show_epoch_progress (bool): False to hide epoch progress bars
+ show_overall_progress (bool): False to hide overall progress bar
+ """
+
+ def __init__(self,
+ metrics_separator=" - ",
+ overall_bar_format='{l_bar}{bar} {n_fmt}/{total_fmt} ETA: '
+ '{remaining}s, {rate_fmt}{postfix}',
+ epoch_bar_format='{n_fmt}/{total_fmt}{bar} ETA: '
+ '{remaining}s - {desc}',
+ update_per_second=10,
+ leave_epoch_progress=True,
+ leave_overall_progress=True,
+ show_epoch_progress=True,
+ show_overall_progress=True):
+
+ try:
+ # import tqdm here because tqdm is not a required package
+ # for addons
+ import tqdm
+ version_message = 'Please update your TQDM version to >= 4.36.1, '
+ 'you have version {}. To update, run !pip install -U tqdm'
+ assert tqdm.__version__ >= '4.36.1', version_message.format(
+ tqdm.__version__)
+ from tqdm.auto import tqdm
+ self.tqdm = tqdm
+ except ImportError:
+ raise ImportError("Please install tqdm via pip install tqdm")
+
+ self.metrics_separator = metrics_separator
+ self.overall_bar_format = overall_bar_format
+ self.epoch_bar_format = epoch_bar_format
+ self.leave_epoch_progress = leave_epoch_progress
+ self.leave_overall_progress = leave_overall_progress
+ self.show_epoch_progress = show_epoch_progress
+ self.show_overall_progress = show_overall_progress
+
+ # compute update interval (inverse of update per second)
+ self.update_interval = 1 / update_per_second
+
+ self.last_update_time = time.time()
+ self.overall_progress_tqdm = None
+ self.epoch_progress_tqdm = None
+ self.num_epochs = None
+ self.logs = None
+ self.metrics = None
+
+ def on_train_begin(self, logs=None):
+ self.num_epochs = self.params['epochs']
+ self.metrics = self.params['metrics']
+
+ if self.show_overall_progress:
+ self.overall_progress_tqdm = self.tqdm(
+ desc='Training',
+ total=self.num_epochs,
+ bar_format=self.overall_bar_format,
+ leave=self.leave_overall_progress,
+ dynamic_ncols=True,
+ unit='epochs')
+
+ # set counting mode
+ if 'samples' in self.params:
+ self.mode = 'samples'
+ self.total_steps = self.params['samples']
+ else:
+ self.mode = 'steps'
+ self.total_steps = self.params['steps']
+
+ def on_train_end(self, logs={}):
+ if self.show_overall_progress:
+ self.overall_progress_tqdm.close()
+
+ def on_epoch_begin(self, epoch, logs={}):
+ current_epoch_description = "Epoch {epoch}/{num_epochs}".format(
+ epoch=epoch + 1, num_epochs=self.num_epochs)
+
+ if self.show_epoch_progress:
+ print(current_epoch_description)
+ self.epoch_progress_tqdm = self.tqdm(
+ total=self.total_steps,
+ bar_format=self.epoch_bar_format,
+ leave=self.leave_epoch_progress,
+ dynamic_ncols=True,
+ unit=self.mode)
+
+ self.seen = 0
+ self.steps_to_update = 0
+ self.logs = defaultdict(float)
+
+ def on_epoch_end(self, epoch, logs={}):
+
+ if self.show_epoch_progress:
+ metrics = self.format_metrics(logs)
+ self.epoch_progress_tqdm.desc = metrics
+
+ # set miniters and mininterval to 0 so last update displays
+ self.epoch_progress_tqdm.miniters = 0
+ self.epoch_progress_tqdm.mininterval = 0
+
+ # update the rest of the steps in epoch progress bar
+ self.epoch_progress_tqdm.update(self.total_steps -
+ self.epoch_progress_tqdm.n)
+ self.epoch_progress_tqdm.close()
+
+ if self.show_overall_progress:
+ self.overall_progress_tqdm.update(1)
+
+ def on_batch_end(self, batch, logs={}):
+ if self.mode == "samples":
+ batch_size = logs['size']
+ else:
+ batch_size = 1
+
+ self.seen += batch_size
+ self.steps_to_update += batch_size
+
+ if self.seen < self.total_steps:
+
+ for metric, value in logs.items():
+ self.logs[metric] += value * batch_size
+
+ now = time.time()
+ time_diff = now - self.last_update_time
+ if self.show_epoch_progress and time_diff >= self.update_interval:
+
+ # update the epoch progress bar
+ metrics = self.format_metrics(self.logs, self.seen)
+ self.epoch_progress_tqdm.desc = metrics
+ self.epoch_progress_tqdm.update(self.steps_to_update)
+
+ # reset steps to update
+ self.steps_to_update = 0
+
+ # update timestamp for last update
+ self.last_update_time = now
+
+ def format_metrics(self, logs={}, factor=1):
+ """Format metrics in logs into a string.
+
+ Arguments:
+ logs: dictionary of metrics and their values. Defaults to
+ empty dictionary.
+ factor (int): The factor we want to divide the metrics in logs
+ by, useful when we are computing the logs after each batch.
+ Defaults to 1.
+
+ Returns:
+ metrics_string: a string displaying metrics using the given
+ formators passed in through the constructor.
+ """
+
+ metric_value_pairs = []
+ for metric in self.metrics:
+ if metric in logs:
+ value = logs[metric] / factor
+ pair = '{name}: {value:0.4f}'.format(name=metric, value=value)
+ metric_value_pairs.append(pair)
+ metrics_string = self.metrics_separator.join(metric_value_pairs)
+ return metrics_string