From 9cc8790aaa736a4bd33750b1ce47b9969bebdc6e Mon Sep 17 00:00:00 2001 From: Shailesh Tanwar <135304487+tanwarsh@users.noreply.github.com> Date: Tue, 11 Feb 2025 09:11:29 +0530 Subject: [PATCH] TensorFlow Taskrunner Workspace with Keras 3 (#1330) * code changes Signed-off-by: yes * code changes Signed-off-by: yes * code changes Signed-off-by: yes * code changes Signed-off-by: yes * formatting changes Signed-off-by: yes * code changes Signed-off-by: yes * code changes Signed-off-by: yes * code changes Signed-off-by: yes * changed dockstring Signed-off-by: yes --------- Signed-off-by: yes --- .../keras/tensorflow/mnist/.workspace | 2 + .../keras/tensorflow/mnist/plan/cols.yaml | 5 + .../keras/tensorflow/mnist/plan/data.yaml | 7 ++ .../keras/tensorflow/mnist/plan/defaults | 2 + .../keras/tensorflow/mnist/plan/plan.yaml | 42 +++++++ .../keras/tensorflow/mnist/requirements.txt | 2 + .../keras/tensorflow/mnist/src/__init__.py | 3 + .../keras/tensorflow/mnist/src/dataloader.py | 42 +++++++ .../keras/tensorflow/mnist/src/mnist_utils.py | 118 ++++++++++++++++++ .../keras/tensorflow/mnist/src/model.py | 96 ++++++++++++++ .../keras/tensorflow/mnist/src/taskrunner.py | 90 +++++++++++++ 11 files changed, 409 insertions(+) create mode 100644 openfl-workspace/keras/tensorflow/mnist/.workspace create mode 100644 openfl-workspace/keras/tensorflow/mnist/plan/cols.yaml create mode 100644 openfl-workspace/keras/tensorflow/mnist/plan/data.yaml create mode 100644 openfl-workspace/keras/tensorflow/mnist/plan/defaults create mode 100644 openfl-workspace/keras/tensorflow/mnist/plan/plan.yaml create mode 100644 openfl-workspace/keras/tensorflow/mnist/requirements.txt create mode 100644 openfl-workspace/keras/tensorflow/mnist/src/__init__.py create mode 100644 openfl-workspace/keras/tensorflow/mnist/src/dataloader.py create mode 100644 openfl-workspace/keras/tensorflow/mnist/src/mnist_utils.py create mode 100644 openfl-workspace/keras/tensorflow/mnist/src/model.py create mode 100644 openfl-workspace/keras/tensorflow/mnist/src/taskrunner.py diff --git a/openfl-workspace/keras/tensorflow/mnist/.workspace b/openfl-workspace/keras/tensorflow/mnist/.workspace new file mode 100644 index 0000000000..3c2c5d08b4 --- /dev/null +++ b/openfl-workspace/keras/tensorflow/mnist/.workspace @@ -0,0 +1,2 @@ +current_plan_name: default + diff --git a/openfl-workspace/keras/tensorflow/mnist/plan/cols.yaml b/openfl-workspace/keras/tensorflow/mnist/plan/cols.yaml new file mode 100644 index 0000000000..b15fc13e97 --- /dev/null +++ b/openfl-workspace/keras/tensorflow/mnist/plan/cols.yaml @@ -0,0 +1,5 @@ +# Copyright (C) 2020-2025 Intel Corporation +# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. + +collaborators: + \ No newline at end of file diff --git a/openfl-workspace/keras/tensorflow/mnist/plan/data.yaml b/openfl-workspace/keras/tensorflow/mnist/plan/data.yaml new file mode 100644 index 0000000000..9a841a6662 --- /dev/null +++ b/openfl-workspace/keras/tensorflow/mnist/plan/data.yaml @@ -0,0 +1,7 @@ +# Copyright (C) 2020-2025 Intel Corporation +# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. + +# collaborator_name,data_directory_path +one,1 + + diff --git a/openfl-workspace/keras/tensorflow/mnist/plan/defaults b/openfl-workspace/keras/tensorflow/mnist/plan/defaults new file mode 100644 index 0000000000..fb82f9c5b6 --- /dev/null +++ b/openfl-workspace/keras/tensorflow/mnist/plan/defaults @@ -0,0 +1,2 @@ +../../workspace/plan/defaults + diff --git a/openfl-workspace/keras/tensorflow/mnist/plan/plan.yaml b/openfl-workspace/keras/tensorflow/mnist/plan/plan.yaml new file mode 100644 index 0000000000..777b94cddf --- /dev/null +++ b/openfl-workspace/keras/tensorflow/mnist/plan/plan.yaml @@ -0,0 +1,42 @@ +# Copyright (C) 2020-2025 Intel Corporation +# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. + +aggregator : + defaults : plan/defaults/aggregator.yaml + template : openfl.component.Aggregator + settings : + init_state_path : save/init.pbuf + best_state_path : save/best.pbuf + last_state_path : save/last.pbuf + rounds_to_train : 10 + +collaborator : + defaults : plan/defaults/collaborator.yaml + template : openfl.component.Collaborator + settings : + delta_updates : false + opt_treatment : RESET + +data_loader : + defaults : plan/defaults/data_loader.yaml + template : src.dataloader.MNISTInMemory + settings : + collaborator_count : 2 + data_group_name : mnist + batch_size : 256 + +task_runner : + defaults : plan/defaults/task_runner.yaml + template : src.taskrunner.CNNTaskruner + +network : + defaults : plan/defaults/network.yaml + +assigner : + defaults : plan/defaults/assigner.yaml + +tasks : + defaults : plan/defaults/tasks_keras.yaml + +compression_pipeline : + defaults : plan/defaults/compression_pipeline.yaml \ No newline at end of file diff --git a/openfl-workspace/keras/tensorflow/mnist/requirements.txt b/openfl-workspace/keras/tensorflow/mnist/requirements.txt new file mode 100644 index 0000000000..9f6cb9c0df --- /dev/null +++ b/openfl-workspace/keras/tensorflow/mnist/requirements.txt @@ -0,0 +1,2 @@ +keras==3.8.0 +tensorflow==2.18.0 \ No newline at end of file diff --git a/openfl-workspace/keras/tensorflow/mnist/src/__init__.py b/openfl-workspace/keras/tensorflow/mnist/src/__init__.py new file mode 100644 index 0000000000..035ee4d0ae --- /dev/null +++ b/openfl-workspace/keras/tensorflow/mnist/src/__init__.py @@ -0,0 +1,3 @@ +# Copyright (C) 2020-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""You may copy this file as the starting point of your own model.""" diff --git a/openfl-workspace/keras/tensorflow/mnist/src/dataloader.py b/openfl-workspace/keras/tensorflow/mnist/src/dataloader.py new file mode 100644 index 0000000000..00ad827a56 --- /dev/null +++ b/openfl-workspace/keras/tensorflow/mnist/src/dataloader.py @@ -0,0 +1,42 @@ +# Copyright (C) 2020-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""You may copy this file as the starting point of your own model.""" + +from openfl.federated import KerasDataLoader +from .mnist_utils import load_mnist_shard + + +class MNISTInMemory(KerasDataLoader): + """Data Loader for MNIST Dataset.""" + + def __init__(self, data_path, batch_size, **kwargs): + """ + Initialize. + + Args: + data_path: File path for the dataset + batch_size (int): The batch size for the data loader + **kwargs: Additional arguments, passed to super init and load_mnist_shard + """ + super().__init__(batch_size, **kwargs) + + try: + int(data_path) + except: + raise ValueError( + "Expected `%s` to be representable as `int`, as it refers to the data shard " + + "number used by the collaborator.", + data_path + ) + + _, num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard( + shard_num=int(data_path), **kwargs + ) + + self.X_train = X_train + self.y_train = y_train + self.X_valid = X_valid + self.y_valid = y_valid + + self.num_classes = num_classes diff --git a/openfl-workspace/keras/tensorflow/mnist/src/mnist_utils.py b/openfl-workspace/keras/tensorflow/mnist/src/mnist_utils.py new file mode 100644 index 0000000000..33a3d93e26 --- /dev/null +++ b/openfl-workspace/keras/tensorflow/mnist/src/mnist_utils.py @@ -0,0 +1,118 @@ +# Copyright (C) 2020-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""You may copy this file as the starting point of your own model.""" + +from logging import getLogger + +import numpy as np +from tensorflow.python.keras.utils.data_utils import get_file + +logger = getLogger(__name__) + + +def one_hot(labels, classes): + """ + One Hot encode a vector. + + Args: + labels (list): List of labels to onehot encode + classes (int): Total number of categorical classes + + Returns: + np.array: Matrix of one-hot encoded labels + """ + return np.eye(classes)[labels] + + +def _load_raw_datashards(shard_num, collaborator_count): + """ + Load the raw data by shard. + + Returns tuples of the dataset shard divided into training and validation. + + Args: + shard_num (int): The shard number to use + collaborator_count (int): The number of collaborators in the federation + + Returns: + 2 tuples: (image, label) of the training, validation dataset + """ + origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/' + path = get_file('mnist.npz', + origin=origin_folder + 'mnist.npz', + file_hash='731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1') + + with np.load(path) as f: + # get all of mnist + X_train_tot = f['x_train'] + y_train_tot = f['y_train'] + + X_valid_tot = f['x_test'] + y_valid_tot = f['y_test'] + + # create the shards + shard_num = int(shard_num) + X_train = X_train_tot[shard_num::collaborator_count] + y_train = y_train_tot[shard_num::collaborator_count] + + X_valid = X_valid_tot[shard_num::collaborator_count] + y_valid = y_valid_tot[shard_num::collaborator_count] + + return (X_train, y_train), (X_valid, y_valid) + + +def load_mnist_shard(shard_num, collaborator_count, categorical=True, + channels_last=True, **kwargs): + """ + Load the MNIST dataset. + + Args: + shard_num (int): The shard to use from the dataset + collaborator_count (int): The number of collaborators in the federation + categorical (bool): True = convert the labels to one-hot encoded + vectors (Default = True) + channels_last (bool): True = The input images have the channels + last (Default = True) + **kwargs: Additional parameters to pass to the function + + Returns: + list: The input shape + int: The number of classes + numpy.ndarray: The training data + numpy.ndarray: The training labels + numpy.ndarray: The validation data + numpy.ndarray: The validation labels + """ + img_rows, img_cols = 28, 28 + num_classes = 10 + + (X_train, y_train), (X_valid, y_valid) = _load_raw_datashards( + shard_num, collaborator_count + ) + + if channels_last: + X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1) + X_valid = X_valid.reshape(X_valid.shape[0], img_rows, img_cols, 1) + input_shape = (img_rows, img_cols, 1) + else: + X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols) + X_valid = X_valid.reshape(X_valid.shape[0], 1, img_rows, img_cols) + input_shape = (1, img_rows, img_cols) + + X_train = X_train.astype('float32') + X_valid = X_valid.astype('float32') + X_train /= 255 + X_valid /= 255 + + logger.info(f'MNIST > X_train Shape : {X_train.shape}') + logger.info(f'MNIST > y_train Shape : {y_train.shape}') + logger.info(f'MNIST > Train Samples : {X_train.shape[0]}') + logger.info(f'MNIST > Valid Samples : {X_valid.shape[0]}') + + if categorical: + # convert class vectors to binary class matrices + y_train = one_hot(y_train, num_classes) + y_valid = one_hot(y_valid, num_classes) + + return input_shape, num_classes, X_train, y_train, X_valid, y_valid diff --git a/openfl-workspace/keras/tensorflow/mnist/src/model.py b/openfl-workspace/keras/tensorflow/mnist/src/model.py new file mode 100644 index 0000000000..6c21ca782d --- /dev/null +++ b/openfl-workspace/keras/tensorflow/mnist/src/model.py @@ -0,0 +1,96 @@ +# Copyright (C) 2020-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""You may copy this file as the starting point of your own model.""" + +import tensorflow as tf +import keras + +class CNNModel(keras.Model): + """ + Custom Keras Model for a Convolutional Neural Network (CNN) with custom training and testing steps. + This model showcase how to define a custom training and testing step for a Keras model with Tensorflow. + Methods + ------- + train_step(data) + Performs a single training step, including forward pass, loss computation, gradient calculation, + and weight updates. Also updates the metrics. + test_step(data) + Performs a single testing step, including forward pass, loss computation, and metric updates. + """ + + def train_step(self, data): + """ + Perform a single training step. + Args: + data (tuple): A tuple containing the input data and labels. If the tuple has three elements, + it should be (x, y, sample_weight). Otherwise, it should be (x, y). + Returns: + dict: A dictionary mapping metric names to their current values. This includes the loss and + any other metrics configured in `compile()`. + Notes: + - The loss function and metrics are configured in the `compile()` method. + - The optimizer is used to apply the computed gradients to the model's trainable variables. + """ + + # Unpack the data. Its structure depends on your model and + # on what you pass to `fit()`. + if len(data) == 3: + x, y, sample_weight = data + else: + sample_weight = None + x, y = data + + with tf.GradientTape() as tape: + y_pred = self(x, training=True) # Forward pass + # Compute the loss value. + # The loss function is configured in `compile()`. + loss = self.compute_loss( + y=y, + y_pred=y_pred, + sample_weight=sample_weight, + ) + + # Compute gradients + trainable_vars = self.trainable_variables + gradients = tape.gradient(loss, trainable_vars) + + # Update weights + self.optimizer.apply(gradients, trainable_vars) + + # Update the metrics. + # Metrics are configured in `compile()`. + for metric in self.metrics: + if metric.name == "loss": + metric.update_state(loss) + else: + metric.update_state(y, y_pred, sample_weight=sample_weight) + + # Return a dict mapping metric names to current value. + # Note that it will include the loss (tracked in self.metrics). + return {m.name: m.result() for m in self.metrics} + + def test_step(self, data): + """ + Perform a single test step. + Args: + data (tuple): A tuple containing the input data (x) and the true labels (y). + Returns: + dict: A dictionary mapping metric names to their current values. This includes the loss and other metrics tracked in self.metrics. + """ + + # Unpack the data + x, y = data + # Compute predictions + y_pred = self(x, training=False) + # Updates the metrics tracking the loss + loss = self.compute_loss(y=y, y_pred=y_pred) + # Update the metrics. + for metric in self.metrics: + if metric.name == "loss": + metric.update_state(loss) + else: + metric.update_state(y, y_pred) + # Return a dict mapping metric names to current value. + # Note that it will include the loss (tracked in self.metrics). + return {m.name: m.result() for m in self.metrics} diff --git a/openfl-workspace/keras/tensorflow/mnist/src/taskrunner.py b/openfl-workspace/keras/tensorflow/mnist/src/taskrunner.py new file mode 100644 index 0000000000..c29cb45cca --- /dev/null +++ b/openfl-workspace/keras/tensorflow/mnist/src/taskrunner.py @@ -0,0 +1,90 @@ +# Copyright (C) 2020-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +"""You may copy this file as the starting point of your own model.""" + +import keras + +from .model import CNNModel +from openfl.federated import KerasTaskRunner + +class CNNTaskruner(KerasTaskRunner): + """A basic convolutional neural network model.""" + + def __init__(self, **kwargs): + """ + Initializes the TaskRunner instance. Builds the Keras model, initializes required tensors for all publicly accessible methods that + could be called as part of a task and initializes the logger. + + Args: + **kwargs: Arbitrary keyword arguments passed to the superclass and used for model building. + + Attributes: + model (keras.Model): The Keras model built using the provided feature shape and number of classes. + logger (logging.Logger): Logger instance for logging information. + + Methods: + build_model: Constructs the Keras model. + initialize_tensorkeys_for_functions: Initializes tensor keys for various functions. + get_train_data_size: Returns the size of the training dataset. + get_valid_data_size: Returns the size of the validation dataset. + """ + super().__init__(**kwargs) + + self.model = self.build_model(self.feature_shape, self.data_loader.num_classes, **kwargs) + + self.initialize_tensorkeys_for_functions() + + self.model.summary(print_fn=self.logger.info) + + self.logger.info(f'Train Set Size : {self.get_train_data_size()}') + self.logger.info(f'Valid Set Size : {self.get_valid_data_size()}') + + def build_model(self, + input_shape, + num_classes, + conv_kernel_size=(4, 4), + conv_strides=(2, 2), + conv1_channels_out=16, + conv2_channels_out=32, + final_dense_inputsize=100, + **kwargs): + """ + Builds and compiles a Convolutional Neural Network (CNN) model. + + Args: + input_shape (tuple): Shape of the input data (height, width, channels). + num_classes (int): Number of output classes. + conv_kernel_size (tuple, optional): Size of the convolutional kernels. Defaults to (4, 4). + conv_strides (tuple, optional): Strides of the convolutional layers. Defaults to (2, 2). + conv1_channels_out (int, optional): Number of output channels for the first convolutional layer. Defaults to 16. + conv2_channels_out (int, optional): Number of output channels for the second convolutional layer. Defaults to 32. + final_dense_inputsize (int, optional): Number of units in the final dense layer before the output layer. Defaults to 100. + **kwargs: Additional keyword arguments. + Returns: + keras.Model: Compiled CNN model. + """ + inputs = keras.Input(shape=input_shape) + outputs = keras.layers.Conv2D(conv1_channels_out, + kernel_size=conv_kernel_size, + strides=conv_strides, + activation='relu', + input_shape=input_shape)(inputs) + outputs = keras.layers.Conv2D(conv2_channels_out, + kernel_size=conv_kernel_size, + strides=conv_strides, + activation='relu')(outputs) + + outputs = keras.layers.Flatten()(outputs) + + outputs = keras.layers.Dense(final_dense_inputsize, activation='relu')(outputs) + + outputs = keras.layers.Dense(num_classes, activation='softmax')(outputs) + + model = CNNModel(inputs, outputs) + + model.compile(loss="categorical_crossentropy", + optimizer="adam", + metrics=["accuracy"]) + + return model