Skip to content

Commit

Permalink
Modify simple fed avg algorithm end-to-end test to use DTensor stack
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 522734269
  • Loading branch information
ishark authored and tensorflow-copybara committed Apr 8, 2023
1 parent 08c5fd0 commit 76281cb
Show file tree
Hide file tree
Showing 4 changed files with 257 additions and 18 deletions.
1 change: 1 addition & 0 deletions tensorflow_federated/examples/simple_fedavg/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ py_cpu_gpu_test(
name = "simple_fedavg_test",
size = "medium",
srcs = ["simple_fedavg_test.py"],
shard_count = 3,
tags = ["nokokoro"],
deps = [
":simple_fedavg_tf",
Expand Down
133 changes: 124 additions & 9 deletions tensorflow_federated/examples/simple_fedavg/simple_fedavg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import collections
from collections.abc import Callable
import functools

from absl.testing import parameterized
import attr
import numpy as np
Expand Down Expand Up @@ -55,6 +54,57 @@ def _create_test_cnn_model():
return model


NUM_LOGICAL_DEVICES = 8


# Initialize logical devices only once for the module.
def setUpModule():
  """Splits the first physical device into NUM_LOGICAL_DEVICES logical ones.

  Prefers a GPU when one is present (capping each logical device's memory and
  initializing the DTensor accelerator system for it); otherwise configures
  logical CPU devices. Runs once for the whole test module.
  """
  physical_gpus = tf.config.list_physical_devices('GPU')
  if physical_gpus:
    gpu_configs = [
        tf.config.LogicalDeviceConfiguration(memory_limit=200)
        for _ in range(NUM_LOGICAL_DEVICES)
    ]
    tf.config.set_logical_device_configuration(physical_gpus[0], gpu_configs)
    tf.experimental.dtensor.initialize_accelerator_system('GPU')
    return
  cpu_devices = tf.config.list_physical_devices('CPU')
  cpu_configs = [
      tf.config.LogicalDeviceConfiguration()
      for _ in range(NUM_LOGICAL_DEVICES)
  ]
  tf.config.set_logical_device_configuration(cpu_devices[0], cpu_configs)


def _setup_local_context(
    use_dtensor_stack_on_server=False, use_dtensor_stack_on_client=False
):
  """Installs a synchronous local TFF execution context for the tests.

  Args:
    use_dtensor_stack_on_server: If True, back server-side computations with a
      DTensor mesh.
    use_dtensor_stack_on_client: If True, back client-side computations with a
      DTensor mesh.
  """
  device_name = tf.experimental.dtensor.preferred_device_type()
  if not (use_dtensor_stack_on_server or use_dtensor_stack_on_client):
    # Neither side wants DTensor: plain TensorFlow executor path.
    tff.backends.native.set_sync_local_cpp_execution_context()
    return

  # DTensor path: one 1-D 'batch' mesh spanning all logical devices, handed
  # to whichever side(s) requested it.
  mesh_dim_name = 'batch'
  device_specs = [
      f'{device_name}:{index}' for index in range(NUM_LOGICAL_DEVICES)
  ]
  mesh = tf.experimental.dtensor.create_mesh(
      devices=device_specs,
      mesh_dims=[(mesh_dim_name, NUM_LOGICAL_DEVICES)],
  )
  # NOTE(review): reaches into a private TFF module — presumably no public
  # API exists yet for the experimental distributed context; confirm.
  tff._native_cpp_execution_contexts.set_sync_experimental_distributed_cpp_execution_context(
      distributed_config=tff._native_cpp_execution_contexts.DistributedConfiguration(
          server_mesh=mesh if use_dtensor_stack_on_server else None,
          client_mesh=mesh if use_dtensor_stack_on_client else None,
      )
  )


def _create_random_batch():
return collections.OrderedDict(
x=tf.random.uniform(tf.TensorShape([1, 28, 28, 1]), dtype=tf.float32),
Expand Down Expand Up @@ -214,11 +264,43 @@ def client_data():
return client_data


@parameterized.named_parameters(
(
'dtensor_on_both_server_client',
True,
True,
True,
), # DTensor on client can only be used with sequence_reduce
(
'dtensor_server_side_use_dataset_iteration',
True,
False,
False,
),
(
'dtensor_on_only_client',
False,
True,
True,
), # DTensor on client can only be used with sequence_reduce
('tensorflow_use_sequence_reduce', False, False, True),
('tensorflow_use_dataset_iteration', False, False, False),
)
class SimpleFedAvgTest(tf.test.TestCase, parameterized.TestCase):

def test_process_construction(self):
def test_process_construction(
self,
use_dtensor_stack_on_server,
use_dtensor_stack_on_client,
use_sequence_reduce,
):
_setup_local_context(
use_dtensor_stack_on_server=use_dtensor_stack_on_server,
use_dtensor_stack_on_client=use_dtensor_stack_on_client,
)
it_process = simple_fedavg_tff.build_federated_averaging_process(
_tff_learning_model_fn
_tff_learning_model_fn,
use_sequence_reduce=use_sequence_reduce,
)
self.assertIsInstance(it_process, tff.templates.IterativeProcess)
federated_data_type = it_process.next.type_signature.parameter[1]
Expand All @@ -236,9 +318,18 @@ def test_process_construction(self):
),
)

def test_training_keras_model_converges(self):
def test_training_keras_model_converges(
self,
use_dtensor_stack_on_server,
use_dtensor_stack_on_client,
use_sequence_reduce,
):
_setup_local_context(
use_dtensor_stack_on_server=use_dtensor_stack_on_server,
use_dtensor_stack_on_client=use_dtensor_stack_on_client,
)
it_process = simple_fedavg_tff.build_federated_averaging_process(
_tff_learning_model_fn
_tff_learning_model_fn, use_sequence_reduce=use_sequence_reduce
)
server_state = it_process.initialize()

Expand All @@ -260,11 +351,22 @@ def deterministic_batch():
previous_loss = loss
self.assertLess(loss, 0.1)

def test_training_custom_model_converges(self):
def test_training_custom_model_converges(
self,
use_dtensor_stack_on_server,
use_dtensor_stack_on_client,
use_sequence_reduce,
):
_setup_local_context(
use_dtensor_stack_on_server=use_dtensor_stack_on_server,
use_dtensor_stack_on_client=use_dtensor_stack_on_client,
)
client_data = _create_client_data()
train_data = [client_data()]

trainer = simple_fedavg_tff.build_federated_averaging_process(MnistModel)
trainer = simple_fedavg_tff.build_federated_averaging_process(
MnistModel, use_sequence_reduce=use_sequence_reduce
)
state = trainer.initialize()
previous_loss = None
for _ in range(10):
Expand Down Expand Up @@ -418,11 +520,24 @@ def test_build_fedavg_process(self):
),
)

def test_client_adagrad_train(self):
@parameterized.named_parameters(
(
'dtensor_server_side_use_dataset_iteration',
True,
False,
), # DTensor can only be used on server side for RNNs, since there is a
# loop inside tf function on client side which is not supported with
# DTensor.
('tensorflow_use_dataset_iteration', False, False),
)
def test_client_adagrad_train(
self, use_dtensor_on_server, use_dtensor_on_client
):
_setup_local_context(use_dtensor_on_server, use_dtensor_on_client)
it_process = simple_fedavg_tff.build_federated_averaging_process(
_rnn_model_fn,
client_optimizer_fn=functools.partial(
tf.keras.optimizers.legacy.Adagrad, learning_rate=0.01
tf.keras.optimizers.SGD, learning_rate=0.01
),
)
server_state = it_process.initialize()
Expand Down
48 changes: 48 additions & 0 deletions tensorflow_federated/examples/simple_fedavg/simple_fedavg_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,54 @@ def build_server_broadcast_message(server_state):
)


@tf.function
def batch_client_update(
    model, batch, initial_weights, num_examples, client_optimizer
):
  """Runs one local training step of `model` on a single `batch`.

  Intended as the reduction function for `tff.sequence_reduce`: each call
  assigns the incoming weights into the model, applies one optimizer step,
  and folds the batch size into the running example count.

  Args:
    model: A `tff.learning.models.VariableModel` to train locally on the
      client.
    batch: A single batch of the client's local data; must contain a `'y'`
      entry whose leading dimension is the batch size.
    initial_weights: Trainable weights to assign into `model` before the
      forward pass.
    num_examples: Number of examples observed so far (arrives as the
      float-valued client weight of the running accumulator).
    client_optimizer: A `tf.keras.optimizers.Optimizer` used to update the
      local model during training.

  Returns:
    A `ClientOutput` instance with a model update to aggregate on the server.
  """
  model_weights = tff.learning.models.ModelWeights.from_model(model)
  tf.nest.map_structure(
      lambda v, t: v.assign(t), model_weights.trainable, initial_weights
  )

  # The running count arrives as a float client weight; count in int32 and
  # cast back to float32 below for the returned client weight.
  num_examples = tf.cast(num_examples, tf.int32)
  with tf.GradientTape() as tape:
    outputs = model.forward_pass(batch)
  grads = tape.gradient(outputs.loss, model_weights.trainable)
  client_optimizer.apply_gradients(zip(grads, model_weights.trainable))
  batch_size = tf.shape(batch['y'])[0]
  num_examples += batch_size

  # Delta of the freshly-trained weights against the weights this step
  # started from.
  weights_delta = tf.nest.map_structure(
      lambda a, b: a - b, model_weights.trainable, initial_weights
  )
  client_weight = tf.cast(num_examples, tf.float32)
  model_outputs = model.report_local_unfinalized_metrics()
  return ClientOutput(weights_delta, client_weight, model_outputs)


@tf.function
def init_client_ouput(model, server_message):
  """Builds the initial `ClientOutput` accumulator for `tff.sequence_reduce`.

  Args:
    model: A `tff.learning.models.VariableModel`, used only to report the
      initial (unfinalized) metrics.
    server_message: Server broadcast message exposing
      `model_weights.trainable`, the server's trainable weights.

  Returns:
    A `ClientOutput` whose `weights_delta` slot carries the server's
    trainable weights (the starting point for the first reduction step) and
    whose client weight is zero.
  """
  client_weight = tf.constant(0, dtype=tf.float32)
  return ClientOutput(
      server_message.model_weights.trainable,
      client_weight,
      model.report_local_unfinalized_metrics(),
  )


@tf.function
def client_update(model, dataset, server_message, client_optimizer):
"""Performans client local training of `model` on `dataset`.
Expand Down
93 changes: 84 additions & 9 deletions tensorflow_federated/examples/simple_fedavg/simple_fedavg_tff.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
import tensorflow as tf
import tensorflow_federated as tff

from tensorflow_federated.examples.simple_fedavg.simple_fedavg_tf import batch_client_update
from tensorflow_federated.examples.simple_fedavg.simple_fedavg_tf import build_server_broadcast_message
from tensorflow_federated.examples.simple_fedavg.simple_fedavg_tf import client_update
from tensorflow_federated.examples.simple_fedavg.simple_fedavg_tf import init_client_ouput
from tensorflow_federated.examples.simple_fedavg.simple_fedavg_tf import server_update
from tensorflow_federated.examples.simple_fedavg.simple_fedavg_tf import ServerState

Expand All @@ -51,10 +53,63 @@ def _initialize_optimizer_vars(
assert optimizer.variables()


def _build_client_update_fn(
    tf_dataset_type,
    server_message_type,
    model_fn,
    client_optimizer_fn,
    use_sequence_reduce,
):
  """Returns the TFF computation for the client update.

  Args:
    tf_dataset_type: `tff.SequenceType` of the client dataset batches.
    server_message_type: TFF type of the server-to-client broadcast message.
    model_fn: A no-arg callable returning a
      `tff.learning.models.VariableModel`.
    client_optimizer_fn: A no-arg callable returning a
      `tf.keras.optimizers.Optimizer`.
    use_sequence_reduce: If True, iterate over the dataset with
      `tff.sequence_reduce` instead of a loop inside the `tf.function`.

  Returns:
    A TFF computation taking `(server_message, dataset)` and returning a
    client output.
  """

  @tff.tf_computation(server_message_type, tf_dataset_type)
  def client_update_fn(server_message, tf_dataset):
    model = model_fn()
    client_optimizer = client_optimizer_fn()
    return client_update(model, tf_dataset, server_message, client_optimizer)

  if not use_sequence_reduce:
    # Use the client update fn with dataset iteration inside the tf.function.
    return client_update_fn
  else:
    # Use a client update function with the dataset iteration lifted out of
    # the tf.function, using tff.sequence_reduce.

    client_update_type_spec = client_update_fn.type_signature.result
    batch_type = model_fn().input_spec

    @tff.tf_computation(client_update_type_spec, batch_type)
    def client_update_batch_fn(client_data, batch):
      # One reduction step: train on a single batch, threading the running
      # accumulator's weights and example count through.
      model = model_fn()
      client_optimizer = client_optimizer_fn()
      # NOTE(review): `client_data.weights_delta` is fed back in as this
      # step's starting weights even though, after the first step, it holds
      # a weight *delta* rather than absolute weights — confirm this is the
      # intended accumulator semantics.
      return batch_client_update(
          model,
          batch,
          client_data.weights_delta,
          client_data.client_weight,
          client_optimizer,
      )

    @tff.tf_computation(server_message_type)
    def initialize_client_data(server_message):
      # Zero-weight accumulator seeded with the server's trainable weights.
      model = model_fn()
      return init_client_ouput(model, server_message)

    @tff.federated_computation(
        server_message_type, tff.SequenceType(batch_type)
    )
    def client_update_weights_fn(server_message, batches):
      client_data = initialize_client_data(server_message)
      return tff.sequence_reduce(batches, client_data, client_update_batch_fn)

    return client_update_weights_fn


def build_federated_averaging_process(
model_fn,
server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
use_sequence_reduce=False,
):
"""Builds the TFF computations for optimization using federated averaging.
Expand All @@ -65,6 +120,9 @@ def build_federated_averaging_process(
`tf.keras.optimizers.Optimizer` for server update.
client_optimizer_fn: A no-arg function that returns a
`tf.keras.optimizers.Optimizer` for client update.
use_sequence_reduce: If true, uses tff.sequence_reduce to perform reduction
across batches of dataset, instead of a for-loop inside client update
method.
Returns:
A `tff.templates.IterativeProcess`.
Expand All @@ -78,7 +136,12 @@ def build_federated_averaging_process(
whimsy_model.metric_finalizers(), unfinalized_metrics_type
)

@tff.tf_computation
@tff.tf_computation(
layout_map={
'weights': 'batch,unsharded',
'SGD/m/weights': 'batch,unsharded',
},
)
def server_init_tf():
model = model_fn()
model_weights = tff.learning.models.ModelWeights.from_model(model)
Expand All @@ -93,7 +156,14 @@ def server_init_tf():
server_state_type = server_init_tf.type_signature.result
model_weights_type = server_state_type.model

@tff.tf_computation(server_state_type, model_weights_type.trainable)
@tff.tf_computation(
server_state_type,
model_weights_type.trainable,
layout_map={
'weights': 'batch,unsharded',
'SGD/m/weights': 'batch,unsharded',
},
)
def server_update_fn(server_state, model_delta):
model = model_fn()
server_optimizer = server_optimizer_fn()
Expand All @@ -107,12 +177,6 @@ def server_message_fn(server_state):
server_message_type = server_message_fn.type_signature.result
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)

@tff.tf_computation(tf_dataset_type, server_message_type)
def client_update_fn(tf_dataset, server_message):
model = model_fn()
client_optimizer = client_optimizer_fn()
return client_update(model, tf_dataset, server_message, client_optimizer)

federated_server_state_type = tff.type_at_server(server_state_type)
federated_dataset_type = tff.type_at_clients(tf_dataset_type)

Expand All @@ -132,10 +196,21 @@ def run_one_round(server_state, federated_dataset):
A tuple of updated `ServerState` and `tf.Tensor` of average loss.
"""
server_message = tff.federated_map(server_message_fn, server_state)
client_update_fn = _build_client_update_fn(
tf_dataset_type,
server_message_type,
model_fn,
client_optimizer_fn,
use_sequence_reduce,
)
server_message_at_client = tff.federated_broadcast(server_message)

client_outputs = tff.federated_map(
client_update_fn, (federated_dataset, server_message_at_client)
client_update_fn,
(
server_message_at_client,
federated_dataset,
),
)

weight_denom = client_outputs.client_weight
Expand Down

0 comments on commit 76281cb

Please sign in to comment.