From 76281cb82adeeb028eb9de4793048de04339bb04 Mon Sep 17 00:00:00 2001 From: Isha Arkatkar Date: Fri, 7 Apr 2023 19:15:22 -0700 Subject: [PATCH] Modify simple fed avg algorithm end-to-end test to use DTensor stack PiperOrigin-RevId: 522734269 --- .../examples/simple_fedavg/BUILD | 1 + .../simple_fedavg/simple_fedavg_test.py | 133 ++++++++++++++++-- .../simple_fedavg/simple_fedavg_tf.py | 48 +++++++ .../simple_fedavg/simple_fedavg_tff.py | 93 ++++++++++-- 4 files changed, 257 insertions(+), 18 deletions(-) diff --git a/tensorflow_federated/examples/simple_fedavg/BUILD b/tensorflow_federated/examples/simple_fedavg/BUILD index 089993549a..2bebcc6f71 100644 --- a/tensorflow_federated/examples/simple_fedavg/BUILD +++ b/tensorflow_federated/examples/simple_fedavg/BUILD @@ -36,6 +36,7 @@ py_cpu_gpu_test( name = "simple_fedavg_test", size = "medium", srcs = ["simple_fedavg_test.py"], + shard_count = 3, tags = ["nokokoro"], deps = [ ":simple_fedavg_tf", diff --git a/tensorflow_federated/examples/simple_fedavg/simple_fedavg_test.py b/tensorflow_federated/examples/simple_fedavg/simple_fedavg_test.py index 6315024b0f..fca1aea980 100644 --- a/tensorflow_federated/examples/simple_fedavg/simple_fedavg_test.py +++ b/tensorflow_federated/examples/simple_fedavg/simple_fedavg_test.py @@ -16,7 +16,6 @@ import collections from collections.abc import Callable import functools - from absl.testing import parameterized import attr import numpy as np @@ -55,6 +54,57 @@ def _create_test_cnn_model(): return model +NUM_LOGICAL_DEVICES = 8 + + +# Initialize logical devices only once for the module. 
+def setUpModule(): + gpu_devices = tf.config.list_physical_devices('GPU') + if gpu_devices: + tf.config.set_logical_device_configuration( + gpu_devices[0], + [ + tf.config.LogicalDeviceConfiguration(memory_limit=200), + ] + * NUM_LOGICAL_DEVICES, + ) + tf.experimental.dtensor.initialize_accelerator_system('GPU') + else: + devices = tf.config.list_physical_devices('CPU') + tf.config.set_logical_device_configuration( + devices[0], + [ + tf.config.LogicalDeviceConfiguration(), + ] + * NUM_LOGICAL_DEVICES, + ) + + +def _setup_local_context( + use_dtensor_stack_on_server=False, use_dtensor_stack_on_client=False +): + device_name = tf.experimental.dtensor.preferred_device_type() + if not use_dtensor_stack_on_server and not use_dtensor_stack_on_client: + # Use Tensorflow executor path + tff.backends.native.set_sync_local_cpp_execution_context() + return + + # Go through DTENSOR path + mesh_dim_name = 'batch' + mesh = tf.experimental.dtensor.create_mesh( + devices=[device_name + ':%d' % i for i in range(NUM_LOGICAL_DEVICES)], + mesh_dims=[(mesh_dim_name, NUM_LOGICAL_DEVICES)], + ) + server_mesh = mesh if use_dtensor_stack_on_server else None + client_mesh = mesh if use_dtensor_stack_on_client else None + tff._native_cpp_execution_contexts.set_sync_experimental_distributed_cpp_execution_context( + distributed_config=tff._native_cpp_execution_contexts.DistributedConfiguration( + server_mesh=server_mesh, + client_mesh=client_mesh, + ) + ) + + def _create_random_batch(): return collections.OrderedDict( x=tf.random.uniform(tf.TensorShape([1, 28, 28, 1]), dtype=tf.float32), @@ -214,11 +264,43 @@ def client_data(): return client_data +@parameterized.named_parameters( + ( + 'dtensor_on_both_server_client', + True, + True, + True, + ), # DTensor on client can only be used with sequence_reduce + ( + 'dtensor_server_side_use_dataset_iteration', + True, + False, + False, + ), + ( + 'dtensor_on_only_client', + False, + True, + True, + ), # DTensor on client can only be used with 
sequence_reduce + ('tensorflow_use_sequence_reduce', False, False, True), + ('tensorflow_use_dataset_iteration', False, False, False), +) class SimpleFedAvgTest(tf.test.TestCase, parameterized.TestCase): - def test_process_construction(self): + def test_process_construction( + self, + use_dtensor_stack_on_server, + use_dtensor_stack_on_client, + use_sequence_reduce, + ): + _setup_local_context( + use_dtensor_stack_on_server=use_dtensor_stack_on_server, + use_dtensor_stack_on_client=use_dtensor_stack_on_client, + ) it_process = simple_fedavg_tff.build_federated_averaging_process( - _tff_learning_model_fn + _tff_learning_model_fn, + use_sequence_reduce=use_sequence_reduce, ) self.assertIsInstance(it_process, tff.templates.IterativeProcess) federated_data_type = it_process.next.type_signature.parameter[1] @@ -236,9 +318,18 @@ def test_process_construction(self): ), ) - def test_training_keras_model_converges(self): + def test_training_keras_model_converges( + self, + use_dtensor_stack_on_server, + use_dtensor_stack_on_client, + use_sequence_reduce, + ): + _setup_local_context( + use_dtensor_stack_on_server=use_dtensor_stack_on_server, + use_dtensor_stack_on_client=use_dtensor_stack_on_client, + ) it_process = simple_fedavg_tff.build_federated_averaging_process( - _tff_learning_model_fn + _tff_learning_model_fn, use_sequence_reduce=use_sequence_reduce ) server_state = it_process.initialize() @@ -260,11 +351,22 @@ def deterministic_batch(): previous_loss = loss self.assertLess(loss, 0.1) - def test_training_custom_model_converges(self): + def test_training_custom_model_converges( + self, + use_dtensor_stack_on_server, + use_dtensor_stack_on_client, + use_sequence_reduce, + ): + _setup_local_context( + use_dtensor_stack_on_server=use_dtensor_stack_on_server, + use_dtensor_stack_on_client=use_dtensor_stack_on_client, + ) client_data = _create_client_data() train_data = [client_data()] - trainer = simple_fedavg_tff.build_federated_averaging_process(MnistModel) + trainer = 
simple_fedavg_tff.build_federated_averaging_process( + MnistModel, use_sequence_reduce=use_sequence_reduce + ) state = trainer.initialize() previous_loss = None for _ in range(10): @@ -418,11 +520,24 @@ def test_build_fedavg_process(self): ), ) - def test_client_adagrad_train(self): + @parameterized.named_parameters( + ( + 'dtensor_server_side_use_dataset_iteration', + True, + False, + ), # DTensor can only be used on server side for RNNs, since there is a + # loop inside tf function on client side which is not supported with + # DTensor. + ('tensorflow_use_dataset_iteration', False, False), + ) + def test_client_adagrad_train( + self, use_dtensor_on_server, use_dtensor_on_client + ): + _setup_local_context(use_dtensor_on_server, use_dtensor_on_client) it_process = simple_fedavg_tff.build_federated_averaging_process( _rnn_model_fn, client_optimizer_fn=functools.partial( - tf.keras.optimizers.legacy.Adagrad, learning_rate=0.01 + tf.keras.optimizers.SGD, learning_rate=0.01 ), ) server_state = it_process.initialize() diff --git a/tensorflow_federated/examples/simple_fedavg/simple_fedavg_tf.py b/tensorflow_federated/examples/simple_fedavg/simple_fedavg_tf.py index 3eb5c3d0ea..38295c84eb 100644 --- a/tensorflow_federated/examples/simple_fedavg/simple_fedavg_tf.py +++ b/tensorflow_federated/examples/simple_fedavg/simple_fedavg_tf.py @@ -142,6 +142,54 @@ def build_server_broadcast_message(server_state): ) +@tf.function +def batch_client_update( + model, batch, initial_weights, num_examples, client_optimizer +): + """Performs client local training of `model` on `dataset`. + + Args: + model: A `tff.learning.models.VariableModel` to train locally on the client. + batch: A batch from 'tf.data.Dataset' representing the clients local data. + initial_weights: initial model weights to use for the update; these are the weights to train. + num_examples: Number of examples observed so far. + client_optimizer: A `tf.keras.optimizers.Optimizer` used to update the local + model during training. 
+ + Returns: + A `ClientOutput` instance with a model update to aggregate on the server. + """ + model_weights = tff.learning.models.ModelWeights.from_model(model) + tf.nest.map_structure( + lambda v, t: v.assign(t), model_weights.trainable, initial_weights + ) + + num_examples = tf.cast(num_examples, tf.int32) + with tf.GradientTape() as tape: + outputs = model.forward_pass(batch) + grads = tape.gradient(outputs.loss, model_weights.trainable) + client_optimizer.apply_gradients(zip(grads, model_weights.trainable)) + batch_size = tf.shape(batch['y'])[0] + num_examples += batch_size + + weights_delta = tf.nest.map_structure( + lambda a, b: a - b, model_weights.trainable, initial_weights + ) + client_weight = tf.cast(num_examples, tf.float32) + model_outputs = model.report_local_unfinalized_metrics() + return ClientOutput(weights_delta, client_weight, model_outputs) + + +@tf.function +def init_client_ouput(model, server_message): + client_weight = tf.constant(0, dtype=tf.float32) + return ClientOutput( + server_message.model_weights.trainable, + client_weight, + model.report_local_unfinalized_metrics(), + ) + + @tf.function def client_update(model, dataset, server_message, client_optimizer): """Performans client local training of `model` on `dataset`. 
diff --git a/tensorflow_federated/examples/simple_fedavg/simple_fedavg_tff.py b/tensorflow_federated/examples/simple_fedavg/simple_fedavg_tff.py index fd5586de8c..d300f472ee 100644 --- a/tensorflow_federated/examples/simple_fedavg/simple_fedavg_tff.py +++ b/tensorflow_federated/examples/simple_fedavg/simple_fedavg_tff.py @@ -29,8 +29,10 @@ import tensorflow as tf import tensorflow_federated as tff +from tensorflow_federated.examples.simple_fedavg.simple_fedavg_tf import batch_client_update from tensorflow_federated.examples.simple_fedavg.simple_fedavg_tf import build_server_broadcast_message from tensorflow_federated.examples.simple_fedavg.simple_fedavg_tf import client_update +from tensorflow_federated.examples.simple_fedavg.simple_fedavg_tf import init_client_ouput from tensorflow_federated.examples.simple_fedavg.simple_fedavg_tf import server_update from tensorflow_federated.examples.simple_fedavg.simple_fedavg_tf import ServerState @@ -51,10 +53,63 @@ def _initialize_optimizer_vars( assert optimizer.variables() +def _build_client_update_fn( + tf_dataset_type, + server_message_type, + model_fn, + client_optimizer_fn, + use_sequence_reduce, +): + """Returns computation for client update.""" + + @tff.tf_computation(server_message_type, tf_dataset_type) + def client_update_fn(server_message, tf_dataset): + model = model_fn() + client_optimizer = client_optimizer_fn() + return client_update(model, tf_dataset, server_message, client_optimizer) + + if not use_sequence_reduce: + # Use client update fn with dataset iteration inside tf function. + return client_update_fn + else: + # Use client update function with dataset iteration lifted out of tf + # function, using tff.sequence_reduce. 
+ + client_update_type_spec = client_update_fn.type_signature.result + batch_type = model_fn().input_spec + + @tff.tf_computation(client_update_type_spec, batch_type) + def client_update_batch_fn(client_data, batch): + model = model_fn() + client_optimizer = client_optimizer_fn() + return batch_client_update( + model, + batch, + client_data.weights_delta, + client_data.client_weight, + client_optimizer, + ) + + @tff.tf_computation(server_message_type) + def initialize_client_data(server_message): + model = model_fn() + return init_client_ouput(model, server_message) + + @tff.federated_computation( + server_message_type, tff.SequenceType(batch_type) + ) + def client_update_weights_fn(server_message, batches): + client_data = initialize_client_data(server_message) + return tff.sequence_reduce(batches, client_data, client_update_batch_fn) + + return client_update_weights_fn + + def build_federated_averaging_process( model_fn, server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0), client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02), + use_sequence_reduce=False, ): """Builds the TFF computations for optimization using federated averaging. @@ -65,6 +120,9 @@ def build_federated_averaging_process( `tf.keras.optimizers.Optimizer` for server update. client_optimizer_fn: A no-arg function that returns a `tf.keras.optimizers.Optimizer` for client update. + use_sequence_reduce: If true, uses tff.sequence_reduce to perform reduction + across batches of dataset, instead of a for-loop inside client update + method. Returns: A `tff.templates.IterativeProcess`. 
@@ -78,7 +136,12 @@ def build_federated_averaging_process( whimsy_model.metric_finalizers(), unfinalized_metrics_type ) - @tff.tf_computation + @tff.tf_computation( + layout_map={ + 'weights': 'batch,unsharded', + 'SGD/m/weights': 'batch,unsharded', + }, + ) def server_init_tf(): model = model_fn() model_weights = tff.learning.models.ModelWeights.from_model(model) @@ -93,7 +156,14 @@ def server_init_tf(): server_state_type = server_init_tf.type_signature.result model_weights_type = server_state_type.model - @tff.tf_computation(server_state_type, model_weights_type.trainable) + @tff.tf_computation( + server_state_type, + model_weights_type.trainable, + layout_map={ + 'weights': 'batch,unsharded', + 'SGD/m/weights': 'batch,unsharded', + }, + ) def server_update_fn(server_state, model_delta): model = model_fn() server_optimizer = server_optimizer_fn() @@ -107,12 +177,6 @@ def server_message_fn(server_state): server_message_type = server_message_fn.type_signature.result tf_dataset_type = tff.SequenceType(whimsy_model.input_spec) - @tff.tf_computation(tf_dataset_type, server_message_type) - def client_update_fn(tf_dataset, server_message): - model = model_fn() - client_optimizer = client_optimizer_fn() - return client_update(model, tf_dataset, server_message, client_optimizer) - federated_server_state_type = tff.type_at_server(server_state_type) federated_dataset_type = tff.type_at_clients(tf_dataset_type) @@ -132,10 +196,21 @@ def run_one_round(server_state, federated_dataset): A tuple of updated `ServerState` and `tf.Tensor` of average loss. 
""" server_message = tff.federated_map(server_message_fn, server_state) + client_update_fn = _build_client_update_fn( + tf_dataset_type, + server_message_type, + model_fn, + client_optimizer_fn, + use_sequence_reduce, + ) server_message_at_client = tff.federated_broadcast(server_message) client_outputs = tff.federated_map( - client_update_fn, (federated_dataset, server_message_at_client) + client_update_fn, + ( + server_message_at_client, + federated_dataset, + ), ) weight_denom = client_outputs.client_weight