TensorFlow Taskrunner Workspace with Keras 3 (#1330)
* code changes
* formatting changes
* changed docstring

Signed-off-by: yes <[email protected]>
Showing 11 changed files with 409 additions and 0 deletions.
@@ -0,0 +1,2 @@
current_plan_name: default
@@ -0,0 +1,5 @@
# Copyright (C) 2020-2025 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

collaborators:
@@ -0,0 +1,7 @@
# Copyright (C) 2020-2025 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

# collaborator_name,data_directory_path
one,1
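The mapping above assigns collaborator "one" the data path "1". In this workspace that value is not a filesystem path but a shard index, which the data loader further below converts with int(). A small illustrative sketch of that interpretation; the variable names are hypothetical and not part of the commit:

# Illustrative sketch only: how a data path of "1" ends up being used as a shard index.
collaborator_data = {"one": "1"}  # the collaborator-to-data-path mapping shown above

for name, data_path in collaborator_data.items():
    shard_num = int(data_path)  # the dataloader performs the same conversion
    print(f"collaborator {name} will load shard {shard_num}")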
@@ -0,0 +1,2 @@
../../workspace/plan/defaults
@@ -0,0 +1,42 @@
# Copyright (C) 2020-2025 Intel Corporation
# Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you.

aggregator :
  defaults : plan/defaults/aggregator.yaml
  template : openfl.component.Aggregator
  settings :
    init_state_path : save/init.pbuf
    best_state_path : save/best.pbuf
    last_state_path : save/last.pbuf
    rounds_to_train : 10

collaborator :
  defaults : plan/defaults/collaborator.yaml
  template : openfl.component.Collaborator
  settings :
    delta_updates : false
    opt_treatment : RESET

data_loader :
  defaults : plan/defaults/data_loader.yaml
  template : src.dataloader.MNISTInMemory
  settings :
    collaborator_count : 2
    data_group_name : mnist
    batch_size : 256

task_runner :
  defaults : plan/defaults/task_runner.yaml
  template : src.taskrunner.CNNTaskruner

network :
  defaults : plan/defaults/network.yaml

assigner :
  defaults : plan/defaults/assigner.yaml

tasks :
  defaults : plan/defaults/tasks_keras.yaml

compression_pipeline :
  defaults : plan/defaults/compression_pipeline.yaml
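The plan wires together the aggregator, collaborator, data loader, task runner, and task defaults; each settings block is what the corresponding component receives at runtime. As a quick sanity check, the file can be read back with a standard YAML parser. This is only an illustrative sketch, assuming the plan is saved as plan/plan.yaml and PyYAML is installed; it is not part of the commit:

# Illustrative sketch: inspect a few settings from the plan file.
import yaml

with open("plan/plan.yaml") as f:
    plan = yaml.safe_load(f)

print(plan["aggregator"]["settings"]["rounds_to_train"])      # 10
print(plan["data_loader"]["settings"]["collaborator_count"])  # 2
print(plan["data_loader"]["template"])                        # src.dataloader.MNISTInMemory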
@@ -0,0 +1,2 @@
keras==3.8.0
tensorflow==2.18.0
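These pins install Keras 3 as a standalone package running on the TensorFlow backend. A minimal, illustrative check (not part of the commit) that the expected versions and backend are active:

# Illustrative sketch: confirm the pinned packages resolve to a Keras 3 + TensorFlow stack.
import keras
import tensorflow as tf

print(keras.__version__)        # expected to start with "3."
print(tf.__version__)           # expected to start with "2.18"
print(keras.backend.backend())  # "tensorflow" when TensorFlow is the active backend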
@@ -0,0 +1,3 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
"""You may copy this file as the starting point of your own model."""
@@ -0,0 +1,42 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""

from openfl.federated import KerasDataLoader
from .mnist_utils import load_mnist_shard


class MNISTInMemory(KerasDataLoader):
    """Data Loader for MNIST Dataset."""

    def __init__(self, data_path, batch_size, **kwargs):
        """
        Initialize.
        Args:
            data_path: The data shard number used by this collaborator (must be representable as int)
            batch_size (int): The batch size for the data loader
            **kwargs: Additional arguments, passed to super init and load_mnist_shard
        """
        super().__init__(batch_size, **kwargs)

        try:
            shard_num = int(data_path)
        except (TypeError, ValueError) as exc:
            raise ValueError(
                f"Expected `{data_path}` to be representable as `int`, as it refers to the "
                "data shard number used by the collaborator."
            ) from exc

        _, num_classes, X_train, y_train, X_valid, y_valid = load_mnist_shard(
            shard_num=shard_num, **kwargs
        )

        self.X_train = X_train
        self.y_train = y_train
        self.X_valid = X_valid
        self.y_valid = y_valid

        self.num_classes = num_classes
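A hypothetical usage sketch for the loader above. It assumes OpenFL is installed and that the script runs from the workspace root so the src package is importable; the arguments mirror the data_loader settings in plan.yaml and the shard index from the data mapping:

# Hypothetical usage sketch, not part of the commit.
from src.dataloader import MNISTInMemory

loader = MNISTInMemory(
    data_path="1",          # shard index, as in the collaborator data mapping
    batch_size=256,         # matches plan.yaml
    collaborator_count=2,   # forwarded to load_mnist_shard via **kwargs
)
print(loader.X_train.shape, loader.num_classes)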
118 additions & 0 deletions: openfl-workspace/keras/tensorflow/mnist/src/mnist_utils.py
@@ -0,0 +1,118 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""

from logging import getLogger

import numpy as np
from keras.utils import get_file

logger = getLogger(__name__)


def one_hot(labels, classes):
    """
    One Hot encode a vector.
    Args:
        labels (list): List of labels to onehot encode
        classes (int): Total number of categorical classes
    Returns:
        np.array: Matrix of one-hot encoded labels
    """
    return np.eye(classes)[labels]


def _load_raw_datashards(shard_num, collaborator_count):
    """
    Load the raw data by shard.
    Returns tuples of the dataset shard divided into training and validation.
    Args:
        shard_num (int): The shard number to use
        collaborator_count (int): The number of collaborators in the federation
    Returns:
        2 tuples: (image, label) of the training, validation dataset
    """
    origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
    path = get_file('mnist.npz',
                    origin=origin_folder + 'mnist.npz',
                    file_hash='731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1')

    with np.load(path) as f:
        # get all of mnist
        X_train_tot = f['x_train']
        y_train_tot = f['y_train']

        X_valid_tot = f['x_test']
        y_valid_tot = f['y_test']

    # create the shards
    shard_num = int(shard_num)
    X_train = X_train_tot[shard_num::collaborator_count]
    y_train = y_train_tot[shard_num::collaborator_count]

    X_valid = X_valid_tot[shard_num::collaborator_count]
    y_valid = y_valid_tot[shard_num::collaborator_count]

    return (X_train, y_train), (X_valid, y_valid)


def load_mnist_shard(shard_num, collaborator_count, categorical=True,
                     channels_last=True, **kwargs):
    """
    Load the MNIST dataset.
    Args:
        shard_num (int): The shard to use from the dataset
        collaborator_count (int): The number of collaborators in the federation
        categorical (bool): True = convert the labels to one-hot encoded
            vectors (Default = True)
        channels_last (bool): True = The input images have the channels
            last (Default = True)
        **kwargs: Additional parameters to pass to the function
    Returns:
        list: The input shape
        int: The number of classes
        numpy.ndarray: The training data
        numpy.ndarray: The training labels
        numpy.ndarray: The validation data
        numpy.ndarray: The validation labels
    """
    img_rows, img_cols = 28, 28
    num_classes = 10

    (X_train, y_train), (X_valid, y_valid) = _load_raw_datashards(
        shard_num, collaborator_count
    )

    if channels_last:
        X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
        X_valid = X_valid.reshape(X_valid.shape[0], img_rows, img_cols, 1)
        input_shape = (img_rows, img_cols, 1)
    else:
        X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)
        X_valid = X_valid.reshape(X_valid.shape[0], 1, img_rows, img_cols)
        input_shape = (1, img_rows, img_cols)

    X_train = X_train.astype('float32')
    X_valid = X_valid.astype('float32')
    X_train /= 255
    X_valid /= 255

    logger.info(f'MNIST > X_train Shape : {X_train.shape}')
    logger.info(f'MNIST > y_train Shape : {y_train.shape}')
    logger.info(f'MNIST > Train Samples : {X_train.shape[0]}')
    logger.info(f'MNIST > Valid Samples : {X_valid.shape[0]}')

    if categorical:
        # convert class vectors to binary class matrices
        y_train = one_hot(y_train, num_classes)
        y_valid = one_hot(y_valid, num_classes)

    return input_shape, num_classes, X_train, y_train, X_valid, y_valid
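The slicing X_train_tot[shard_num::collaborator_count] gives each collaborator an interleaved, non-overlapping subset of MNIST, and one_hot turns integer labels into rows of an identity matrix. A small numpy-only illustration with synthetic data, not part of the commit:

# Illustrative sketch with synthetic data: interleaved sharding and one-hot encoding.
import numpy as np

labels = np.arange(10) % 3              # pretend labels for 10 samples: 0, 1, 2, 0, 1, ...
collaborator_count = 2

shard0 = labels[0::collaborator_count]  # samples 0, 2, 4, 6, 8
shard1 = labels[1::collaborator_count]  # samples 1, 3, 5, 7, 9
assert len(shard0) + len(shard1) == len(labels)  # the shards partition the data

encoded = np.eye(3)[shard0]             # same trick as one_hot() above
print(encoded.shape)                    # (5, 3)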
@@ -0,0 +1,96 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""You may copy this file as the starting point of your own model."""

import tensorflow as tf
import keras


class CNNModel(keras.Model):
    """
    Custom Keras Model for a Convolutional Neural Network (CNN) with custom training and testing steps.
    This model showcases how to define custom training and testing steps for a Keras model with TensorFlow.
    Methods
    -------
    train_step(data)
        Performs a single training step, including forward pass, loss computation, gradient calculation,
        and weight updates. Also updates the metrics.
    test_step(data)
        Performs a single testing step, including forward pass, loss computation, and metric updates.
    """

    def train_step(self, data):
        """
        Perform a single training step.
        Args:
            data (tuple): A tuple containing the input data and labels. If the tuple has three elements,
                it should be (x, y, sample_weight). Otherwise, it should be (x, y).
        Returns:
            dict: A dictionary mapping metric names to their current values. This includes the loss and
                any other metrics configured in `compile()`.
        Notes:
            - The loss function and metrics are configured in the `compile()` method.
            - The optimizer is used to apply the computed gradients to the model's trainable variables.
        """
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value.
            # The loss function is configured in `compile()`.
            loss = self.compute_loss(
                y=y,
                y_pred=y_pred,
                sample_weight=sample_weight,
            )

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply(gradients, trainable_vars)

        # Update the metrics.
        # Metrics are configured in `compile()`.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred, sample_weight=sample_weight)

        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """
        Perform a single test step.
        Args:
            data (tuple): A tuple containing the input data (x) and the true labels (y).
        Returns:
            dict: A dictionary mapping metric names to their current values. This includes the loss
                and other metrics tracked in self.metrics.
        """
        # Unpack the data
        x, y = data
        # Compute predictions
        y_pred = self(x, training=False)
        # Compute the loss tracked by the metrics
        loss = self.compute_loss(y=y, y_pred=y_pred)
        # Update the metrics.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}
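A hedged usage sketch for CNNModel: it builds a small CNN with the Keras functional API, wraps it in CNNModel so the custom train_step and test_step are exercised, and fits it on random data. The layer sizes and names are illustrative choices, not taken from this commit:

# Illustrative sketch, not part of the commit: exercise CNNModel's custom steps on random data.
# Assumes CNNModel from the file above is in scope (e.g. from src.taskrunner import CNNModel).
import numpy as np
import keras

inputs = keras.Input(shape=(28, 28, 1))
x = keras.layers.Conv2D(16, 3, activation="relu")(inputs)
x = keras.layers.MaxPooling2D()(x)
x = keras.layers.Flatten()(x)
outputs = keras.layers.Dense(10, activation="softmax")(x)

model = CNNModel(inputs, outputs)  # functional construction of the subclassed model
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

X = np.random.rand(64, 28, 28, 1).astype("float32")
y = np.eye(10)[np.random.randint(0, 10, size=64)]  # one-hot labels, as the dataloader produces

model.fit(X, y, batch_size=32, epochs=1)       # runs through the custom train_step
print(model.evaluate(X, y, return_dict=True))  # runs through the custom test_step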