Skip to content

Commit

Permalink
Clean up TB summary logs.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhijinl committed Jul 16, 2024
1 parent ac14832 commit 74dede1
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 11 deletions.
Binary file modified examples/getting_started/tf/figs/fedavg-diff-algos.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/getting_started/tf/figs/fedavg-diff-alphas.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/getting_started/tf/figs/fedavg-vs-centralized.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 6 additions & 11 deletions examples/getting_started/tf/src/cifar10_tf_fl_alpha_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@
# (1) import nvflare client API
import nvflare.client as flare

# (optional) metrics
from nvflare.client.tracking import SummaryWriter


PATH = "./tf_model.weights.h5"

Expand Down Expand Up @@ -112,7 +109,7 @@ def preprocess_dataset(dataset, is_training, batch_size=1):
Tensorflow Dataset with pre-processings applied.
"""
# Values from: https://github.com/NVIDIA/NVFlare/blob/fc2bc47889b980c8de37de5528e3d07e6b1a942e/examples/advanced/cifar10/pt/learners/cifar10_model_learner.py#L147
# Values from: https://github.com/NVIDIA/NVFlare/blob/main/examples/advanced/cifar10/pt/learners/cifar10_model_learner.py#L147
mean_cifar10 = tf.constant([125.3, 123.0, 113.9], dtype=tf.float32)
std_cifar10 = tf.constant([63.0, 62.1, 66.7], dtype=tf.float32)

Expand Down Expand Up @@ -202,7 +199,10 @@ def main():
model = ModerateTFNet()
model.build(input_shape=(None, 32, 32, 3))

callbacks = [tf.keras.callbacks.TensorBoard(log_dir="./logs_keras", write_graph=False)]
# Tensorboard logs for each local training epoch
callbacks = [tf.keras.callbacks.TensorBoard(log_dir="./logs/epochs", write_graph=False)]
# Tensorboard logs for each aggregation run
tf_summary_writer = tf.summary.create_file_writer(logdir="./logs/rounds")

# Control whether FedProx is used.
if args.fedprox_mu > 0:
Expand All @@ -219,8 +219,6 @@ def main():
# (2) initializes NVFlare client API
flare.init()

summary_writer = SummaryWriter()
tf_summary_writer = tf.summary.create_file_writer(logdir="./logs/validation")
while flare.is_running():
# (3) receives FLModel from NVFlare
input_model = flare.receive()
Expand All @@ -236,7 +234,6 @@ def main():

# (5) evaluate aggregated/received model
_, test_global_acc = model.evaluate(x=test_ds, verbose=2)
summary_writer.add_scalar(tag="global_model_accuracy", scalar=test_global_acc, global_step=input_model.current_round)

with tf_summary_writer.as_default():
tf.summary.scalar("global_model_accuracy", test_global_acc, input_model.current_round)
Expand All @@ -254,7 +251,7 @@ def main():
validation_data=test_ds,
callbacks=callbacks,
initial_epoch=start_epoch,
validation_freq=1 #args.epochs
validation_freq=1
)

print("Finished Training")
Expand All @@ -263,8 +260,6 @@ def main():

_, test_acc = model.evaluate(x=test_ds, verbose=2)

summary_writer.add_scalar(tag="local_model_accuracy", scalar=test_acc, global_step=input_model.current_round)

with tf_summary_writer.as_default():
tf.summary.scalar("local_model_accuracy", test_acc, input_model.current_round)
print(f"Accuracy of the model on the {len(test_images)} test images: {test_acc * 100} %")
Expand Down

0 comments on commit 74dede1

Please sign in to comment.